Update EvoNorm implementations with some experimentation. Add FilterResponseNorm. Update RegNetZ and ResNetV2 model defs for trials.

pull/1014/head
Ross Wightman 3 years ago
parent 55adfbeb8d
commit 78912b6375

@@ -35,7 +35,8 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d
+create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a, \
+EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
from .registry import register_model
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@@ -152,6 +153,12 @@ default_cfgs = {
'regnetz_e8': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
'regnetz_b16_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv',
crop_pct=0.94),
'regnetz_d8_evob': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
@@ -597,6 +604,23 @@ model_cfgs = dict(
),
# experimental EvoNorm configs
regnetz_b16_evos=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
),
stem_chs=32,
stem_pool='',
downsample='',
num_features=1536,
act_layer='silu',
norm_layer=partial(EvoNorm2dS0a, group_size=16),
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_d8_evob=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
@@ -610,7 +634,7 @@ model_cfgs = dict(
downsample='',
num_features=1792,
act_layer='silu',
-norm_layer='evonormbatch',
+norm_layer='evonormb0',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
@@ -628,7 +652,7 @@ model_cfgs = dict(
downsample='',
num_features=1792,
act_layer='silu',
-norm_layer=partial(EvoNormSample2d, groups=32),
+norm_layer=partial(EvoNorm2dS0a, group_size=16),
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
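Note: in the two configs above, a group_size=16 partial replaces the old fixed groups=32. In the EvoNorm-S constructors further down, group_size takes precedence and the group count is derived as num_features // group_size, so wider stages automatically get more groups. A tiny sketch of that resolution, where resolve_groups is an illustrative helper rather than a timm function:

# Illustrative helper mirroring the group_size handling in the EvoNorm2dS* constructors below.
def resolve_groups(num_features: int, groups: int = 32, group_size: int = None) -> int:
    if group_size:
        assert num_features % group_size == 0
        return num_features // group_size
    return groups

assert resolve_groups(192, group_size=16) == 12   # e.g. a 192-channel RegNetZ stage
assert resolve_groups(192) == 32                  # old behaviour: fixed group count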
@@ -856,6 +880,13 @@ def regnetz_e8(pretrained=False, **kwargs):
return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)
@register_model
def regnetz_b16_evos(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs)
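Note: once registered as above, the experimental model builds like any other timm model. A minimal usage sketch (no pretrained weights exist for this config, so random init is assumed):

import torch
import timm

# Build the new EvoNorm-S RegNetZ variant registered above.
model = timm.create_model('regnetz_b16_evos', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # default_cfg input_size is (3, 224, 224)
print(out.shape)  # torch.Size([1, 1000])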
@register_model
def regnetz_d8_evob(pretrained=False, **kwargs):
"""

@@ -14,6 +14,8 @@ except ImportError:
# Layers we want to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
@@ -27,6 +29,9 @@ _leaf_modules = {
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath,  # reason: TypeError: rand received Proxy in `size` argument
EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?)
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
}
try:

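Note: adding the EvoNorm classes to _leaf_modules keeps torch.fx feature extraction from tracing into their forward() (the .to(dtype) casts and buffer updates flagged above do not symbolically trace cleanly). The general mechanism, sketched with a stand-alone torch.fx Tracer and a hypothetical MyNorm module rather than timm's own helper:

import torch
import torch.fx
import torch.nn as nn

class MyNorm(nn.Module):
    # Hypothetical stand-in for a module whose forward() does not trace cleanly.
    def forward(self, x):
        return x / x.var(dim=(2, 3), keepdim=True).add(1e-5).sqrt().to(x.dtype)

class LeafTracer(torch.fx.Tracer):
    # Treat MyNorm as an opaque call_module node instead of tracing its internals.
    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, MyNorm) or super().is_leaf_module(m, module_qualified_name)

net = nn.Sequential(nn.Conv2d(3, 8, 3), MyNorm())
gm = torch.fx.GraphModule(net, LeafTracer().trace(net))
print(gm.graph)  # MyNorm shows up as a single call_module node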
@@ -14,7 +14,9 @@ from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
-from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, \
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible

@@ -116,9 +116,6 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
# custom autograd, then fallback
if name in _ACT_FN_ME:
return _ACT_FN_ME[name]
-if is_exportable() and name in ('silu', 'swish'):
-# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
-return swish
if not (is_no_jit() or is_exportable()):
if name in _ACT_FN_JIT:
return _ACT_FN_JIT[name]
@@ -132,14 +129,12 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
"""
if not name:
return None
-if isinstance(name, type):
+if not isinstance(name, str):
# callable, module, etc.
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
-if is_exportable() and name in ('silu', 'swish'):
-# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
-return Swish
if not (is_no_jit() or is_exportable()):
if name in _ACT_LAYER_JIT:
return _ACT_LAYER_JIT[name]

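Note: with the isinstance check above, get_act_layer() now passes through any non-string argument (a module class, a functools.partial, ...) instead of only bare type objects. A small sketch of the intended behaviour, assuming the usual timm.models.layers export:

import torch.nn as nn
from functools import partial
from timm.models.layers import get_act_layer

# Strings still resolve through the activation registries.
act_cls = get_act_layer('silu')
print(act_cls)

# Anything that isn't a string is now returned unchanged.
assert get_act_layer(nn.GELU) is nn.GELU
act_partial = partial(nn.LeakyReLU, negative_slope=0.2)
assert get_act_layer(act_partial) is act_partial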
@@ -9,36 +9,42 @@ Hacked together by / Copyright 2020 Ross Wightman
import types
import functools
-import torch
-import torch.nn as nn
-from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .evo_norm import *
+from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .inplace_abn import InplaceAbn
-_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
-_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn}  # requires act_layer arg to define act type
+_NORM_ACT_MAP = dict(
+batchnorm=BatchNormAct2d,
groupnorm=GroupNormAct,
evonormb0=EvoNorm2dB0,
evonormb1=EvoNorm2dB1,
evonormb2=EvoNorm2dB2,
evonorms0=EvoNorm2dS0,
evonorms0a=EvoNorm2dS0a,
evonorms1=EvoNorm2dS1,
evonorms1a=EvoNorm2dS1a,
evonorms2=EvoNorm2dS2,
evonorms2a=EvoNorm2dS2a,
frn=FilterResponseNormAct2d,
frntlu=FilterResponseNormTlu2d,
inplaceabn=InplaceAbn,
iabn=InplaceAbn,
)
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
# has act_layer arg to define act type
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn}
-def get_norm_act_layer(layer_class):
-layer_class = layer_class.replace('_', '').lower()
-if layer_class.startswith("batchnorm"):
-layer = BatchNormAct2d
-elif layer_class.startswith("groupnorm"):
-layer = GroupNormAct
-elif layer_class == "evonormbatch":
-layer = EvoNormBatch2d
-elif layer_class == "evonormsample":
-layer = EvoNormSample2d
-elif layer_class == "iabn" or layer_class == "inplaceabn":
-layer = InplaceAbn
-else:
-assert False, "Invalid norm_act layer (%s)" % layer_class
+def get_norm_act_layer(layer_name):
+layer_name = layer_name.replace('_', '').lower().split('-')[0]
+layer = _NORM_ACT_MAP.get(layer_name, None)
+assert layer is not None, "Invalid norm_act layer (%s)" % layer_name
return layer
-def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
-layer_parts = layer_type.split('-')  # e.g. batchnorm-leaky_relu
+def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs):
+layer_parts = layer_name.split('-')  # e.g. batchnorm-leaky_relu
assert len(layer_parts) in (1, 2)
layer = get_norm_act_layer(layer_parts[0])
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection?

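Note: the string keys in _NORM_ACT_MAP are what config values like norm_layer='evonormb0' above resolve through; the lookup is case- and underscore-insensitive and ignores anything after a '-' (the 'batchnorm-leaky_relu' form mentioned in the comment). A short usage sketch, assuming these factories keep their exported names in timm.models.layers:

from timm.models.layers import get_norm_act_layer, create_norm_act

# String lookup: case/underscore-insensitive, any act suffix after '-' is ignored for the lookup.
assert get_norm_act_layer('evonormb0').__name__ == 'EvoNorm2dB0'
assert get_norm_act_layer('EvoNorm_S0a').__name__ == 'EvoNorm2dS0a'

# Factory form: build a fused norm+act layer for a 64-channel feature map.
norm_act = create_norm_act('evonorms0', num_features=64)
print(type(norm_act).__name__)  # EvoNorm2dS0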
@@ -1,81 +1,332 @@
-"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
+""" EvoNorm in PyTorch
Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967
@inproceedings{NEURIPS2020,
author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
pages = {13539--13550},
publisher = {Curran Associates, Inc.},
title = {Evolving Normalization-Activation Layers},
url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf},
volume = {33},
year = {2020}
}
An attempt at getting decent performing EvoNorms running in PyTorch.
-While currently faster than other impl, still quite a ways off the built-in BN
-in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
-Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
+While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm
+in terms of memory usage and throughput on GPUs.
+I'm testing these modules on TPU w/ PyTorch XLA. Promising start but
currently working around some issues with builtin torch/tensor.var/std. Unlike
GPU, similar train speeds for EvoNormS variants and BatchNorm.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .trace_utils import _assert
-class EvoNormBatch2d(nn.Module):
-def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
-super(EvoNormBatch2d, self).__init__()
+def instance_std(x, eps: float = 1e-5):
+rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
+return rms.expand(x.shape)
def instance_rms(x, eps: float = 1e-5):
rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
return rms.expand(x.shape)
def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
B, C, H, W = x.shape
x_dtype = x.dtype
_assert(C % groups == 0, '')
# x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
# std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
x = x.reshape(B, groups, C // groups, H, W)
std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
# This is a workaround for some stability / odd behaviour of .var and .std
# running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
B, C, H, W = x.shape
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.float().reshape(B, groups, C // groups, H, W)
xm = x.mean(dim=(2, 3, 4), keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU
var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
else:
var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
# group_std = group_std_tpu # temporary, for TPU / PT XLA
def group_rms(x, groups: int = 32, eps: float = 1e-5):
B, C, H, W = x.shape
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
return sqm.expand(x.shape).reshape(B, C, H, W)
class EvoNorm2dB0(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act  # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
-self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
-self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
-self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+self.weight = nn.Parameter(torch.ones(num_features))
+self.bias = nn.Parameter(torch.zeros(num_features))
+self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
-if self.apply_act:
+if self.v is not None:
nn.init.ones_(self.v)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
-x_type = x.dtype
+x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.v is not None:
-running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
-var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+var = x.float().var(dim=(0, 2, 3), unbiased=False)
n = x.numel() / x.shape[1]
-running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
-self.running_var.copy_(running_var.view(self.running_var.shape))
+self.running_var.copy_(
+self.running_var * (1 - self.momentum) +
+var.detach() * self.momentum * (n / (n - 1)))
else:
-var = running_var
-v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
-d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
-d = d.max((var + self.eps).sqrt().to(dtype=x_type))
-x = x / d
-return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+var = self.running_var
+left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)
+v = self.v.to(x_dtype).view(v_shape)
+right = x * v + instance_std(x, self.eps)
+x = x / left.max(right)
+return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape)
-class EvoNormSample2d(nn.Module):
-def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
-super(EvoNormSample2d, self).__init__()
+class EvoNorm2dB1(nn.Module):
+def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
+super().__init__()
self.apply_act = apply_act  # apply activation (non-linearity)
-self.groups = groups
+self.momentum = momentum
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
if self.training:
var = x.float().var(dim=(0, 2, 3), unbiased=False)
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = (x + 1) * instance_rms(x, self.eps)
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dB2(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
-self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
-self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
-self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+self.weight = nn.Parameter(torch.ones(num_features))
+self.bias = nn.Parameter(torch.zeros(num_features))
+self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
if self.training:
var = x.float().var(dim=(0, 2, 3), unbiased=False)
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = instance_rms(x, self.eps) - x
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS0(nn.Module):
def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
if group_size:
assert num_features % group_size == 0
self.groups = num_features // group_size
else:
self.groups = groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.v is not None:
nn.init.ones_(self.v)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
-B, C, H, W = x.shape
-_assert(C % self.groups == 0, '')
+x_dtype = x.dtype
+v_shape = (1, -1, 1, 1)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS0a(EvoNorm2dS0):
def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
d = group_std(x, self.groups, self.eps)
if self.v is not None: if self.v is not None:
n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid() v = self.v.view(v_shape).to(dtype=x_dtype)
x = x.reshape(B, self.groups, -1) x = x * (x * v).sigmoid_()
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() x = x / d
x = x.reshape(B, C, H, W) return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
class EvoNorm2dS1(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
else:
self.act = nn.Identity()
if group_size:
assert num_features % group_size == 0
self.groups = num_features // group_size
else:
self.groups = groups
self.eps = eps
self.pre_act_norm = False
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS1a(EvoNorm2dS1):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS2(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
else:
self.act = nn.Identity()
if group_size:
assert num_features % group_size == 0
self.groups = num_features // group_size
else:
self.groups = groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS2a(EvoNorm2dS2):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)

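Note: for orientation, the S0 variant above replaces the usual norm + activation pair with x * sigmoid(v * x) / group_std(x), with the affine weight/bias applied afterwards, while the B0 variant normalizes by the max of a batch-statistics term and an instance term. A minimal usage sketch of the new modules (constructor arguments taken from this file; random inputs assumed):

import torch
from timm.models.layers import EvoNorm2dB0, EvoNorm2dS0

x = torch.randn(2, 64, 56, 56)

# Batch variant: keeps a running_var buffer, so train and eval behave differently.
b0 = EvoNorm2dB0(64)
b0.train(); y_train = b0(x)
b0.eval();  y_eval = b0(x)

# Sample variant: per-sample group statistics; group_size=16 -> 4 groups of 16 channels here.
s0 = EvoNorm2dS0(64, group_size=16)
y = s0(x)
assert y.shape == x.shape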
@@ -0,0 +1,68 @@
""" Filter Response Norm in PyTorch
Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
import torch.nn as nn
from .create_act import create_act_layer
from .trace_utils import _assert
def inv_instance_rms(x, eps: float = 1e-5):
rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
return rms.expand(x.shape)
class FilterResponseNormTlu2d(nn.Module):
def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
super(FilterResponseNormTlu2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.rms = rms
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.tau is not None:
nn.init.zeros_(self.tau)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = x * inv_instance_rms(x, self.eps)
x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
class FilterResponseNormAct2d(nn.Module):
def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
super(FilterResponseNormAct2d, self).__init__()
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer, inplace=inplace)
else:
self.act = nn.Identity()
self.rms = rms
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = x * inv_instance_rms(x, self.eps)
x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return self.act(x)

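Note: Filter Response Normalization divides each channel by its per-sample spatial RMS, nu2 = mean(x^2) over H and W, i.e. y = gamma * x / sqrt(nu2 + eps) + beta, and the TLU variant then applies a learned threshold, z = max(y, tau). A small numerical sketch of the same computation in plain tensor ops (frn_tlu_reference is an illustrative helper mirroring inv_instance_rms above, not part of timm):

import torch

def frn_tlu_reference(x, gamma, beta, tau, eps=1e-5):
    # nu2: per-sample, per-channel mean of squares over the spatial dims
    nu2 = x.square().mean(dim=(2, 3), keepdim=True)
    y = x * torch.rsqrt(nu2 + eps)
    y = y * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)
    return torch.maximum(y, tau.view(1, -1, 1, 1))  # TLU threshold

x = torch.randn(2, 8, 14, 14)
gamma, beta, tau = torch.ones(8), torch.zeros(8), torch.zeros(8)
out = frn_tlu_reference(x, gamma, beta, tau)
assert out.shape == x.shape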
@@ -38,7 +38,8 @@ from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model
-from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d, \
+from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, \
+EvoNorm2dS1, EvoNorm2dS2, FilterResponseNormTlu2d, FilterResponseNormAct2d, \
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
@@ -125,7 +126,11 @@ default_cfgs = {
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evob': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
-'resnetv2_50d_evos': _cfg(
+'resnetv2_50d_evos0': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos1': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_frn': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
}
@@ -660,13 +665,29 @@ def resnetv2_50d_gn(pretrained=False, **kwargs):
def resnetv2_50d_evob(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evob', pretrained=pretrained,
-layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
+layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
@register_model
def resnetv2_50d_evos0(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos0', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evos1(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos1', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=partial(EvoNorm2dS1, group_size=16),
stem_type='deep', avg_down=True, **kwargs)
@register_model
-def resnetv2_50d_evos(pretrained=False, **kwargs):
+def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(
-'resnetv2_50d_evos', pretrained=pretrained,
-layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
+'resnetv2_50d_frn', pretrained=pretrained,
+layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
stem_type='deep', avg_down=True, **kwargs)

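Note: the renamed and added ResNetV2 variants are likewise buildable straight from the registry; as with the RegNetZ trials above, no pretrained weights are assumed. A short usage sketch:

import torch
import timm

# EvoNorm-S0 and FRN+TLU ResNetV2-50d variants registered above (random init assumed).
for name in ('resnetv2_50d_evos0', 'resnetv2_50d_frn'):
    m = timm.create_model(name, pretrained=False, num_classes=10)
    with torch.no_grad():
        logits = m(torch.randn(1, 3, 224, 224))
    print(name, logits.shape)  # (1, 10)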
@@ -395,7 +395,7 @@ def eca_vovnet39b(pretrained=False, **kwargs):
@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs):
def norm_act_fn(num_features, **nkwargs):
-return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs)
+return create_norm_act('evonorms0', num_features, jit=False, **nkwargs)
return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
