Monster commit, activation refactor, VoVNet, norm_act improvements, more

* refactor activations into basic PyTorch, jit scripted, and memory efficient custom auto
* implement hard-mish, better grad for hard-swish
* add initial VovNet V1/V2 impl, fix #151
* VovNet and DenseNet first models to use NormAct layers (support BatchNormAct2d, EvoNorm, InplaceIABN)
* Wrap IABN for any models that use it
* make more models torchscript compatible (DPN, PNasNet, Res2Net, SelecSLS) and add tests
pull/155/head
Ross Wightman 5 years ago
parent ff94ffce61
commit eb7653614f

@ -4,7 +4,7 @@ import platform
import os
import fnmatch
from timm import list_models, create_model
from timm import list_models, create_model, set_scriptable
if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
@ -53,6 +53,8 @@ def test_model_backward(model_name, batch_size):
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
@ -83,3 +85,25 @@ def test_model_default_cfgs(model_name, batch_size):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params'
assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', # hopefully fix at some point
]
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_torchscript(model_name, batch_size):
"""Run a single forward pass with each model"""
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -1,2 +1,3 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable

@ -20,9 +20,11 @@ from .sknet import *
from .tresnet import *
from .resnest import *
from .regnet import *
from .vovnet import *
from .registry import *
from .factory import create_model
from .helpers import load_checkpoint, resume_checkpoint
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit

@ -41,13 +41,13 @@ default_cfgs = {
class DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, norm_act_layer=BatchNormAct2d,
def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(DenseLayer, self).__init__()
self.add_module('norm1', norm_act_layer(num_input_features)),
self.add_module('norm1', norm_layer(num_input_features)),
self.add_module('conv1', nn.Conv2d(
num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
self.add_module('norm2', norm_act_layer(bn_size * growth_rate)),
self.add_module('norm2', norm_layer(bn_size * growth_rate)),
self.add_module('conv2', nn.Conv2d(
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
self.drop_rate = float(drop_rate)
@ -109,7 +109,7 @@ class DenseLayer(nn.Module):
class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_act_layer=nn.ReLU,
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
@ -117,7 +117,7 @@ class DenseBlock(nn.ModuleDict):
num_input_features + i * growth_rate,
growth_rate=growth_rate,
bn_size=bn_size,
norm_act_layer=norm_act_layer,
norm_layer=norm_layer,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
@ -132,9 +132,9 @@ class DenseBlock(nn.ModuleDict):
class DenseTransition(nn.Sequential):
def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None):
def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
super(DenseTransition, self).__init__()
self.add_module('norm', norm_act_layer(num_input_features))
self.add_module('norm', norm_layer(num_input_features))
self.add_module('conv', nn.Conv2d(
num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
if aa_layer is not None:
@ -160,7 +160,7 @@ class DenseNet(nn.Module):
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
num_classes=1000, in_chans=3, global_pool='avg',
norm_act_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False):
norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
@ -181,17 +181,17 @@ class DenseNet(nn.Module):
stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
('norm0', norm_act_layer(stem_chs_1)),
('norm0', norm_layer(stem_chs_1)),
('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
('norm1', norm_act_layer(stem_chs_2)),
('norm1', norm_layer(stem_chs_2)),
('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
('norm2', norm_act_layer(num_init_features)),
('norm2', norm_layer(num_init_features)),
('pool0', stem_pool),
]))
else:
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('norm0', norm_act_layer(num_init_features)),
('norm0', norm_layer(num_init_features)),
('pool0', stem_pool),
]))
@ -203,7 +203,7 @@ class DenseNet(nn.Module):
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
norm_act_layer=norm_act_layer,
norm_layer=norm_layer,
drop_rate=drop_rate,
memory_efficient=memory_efficient
)
@ -212,12 +212,12 @@ class DenseNet(nn.Module):
if i != len(block_config) - 1:
trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2,
norm_act_layer=norm_act_layer)
norm_layer=norm_layer)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', norm_act_layer(num_features))
self.features.add_module('norm5', norm_layer(num_features))
# Linear layer
self.num_features = num_features
@ -346,7 +346,7 @@ def densenet121d_evob(pretrained=False, **kwargs):
return create_norm_act('EvoNormBatch', num_features, jit=True, **kwargs)
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model
@ -359,7 +359,7 @@ def densenet121d_evos(pretrained=False, **kwargs):
return create_norm_act('EvoNormSample', num_features, jit=True, **kwargs)
model = _densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model
@ -372,7 +372,7 @@ def densenet121d_iabn(pretrained=False, **kwargs):
return create_norm_act('iabn', num_features, **kwargs)
model = _densenet(
'densenet121tn', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
norm_act_layer=norm_act_fn, pretrained=pretrained, **kwargs)
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
return model

@ -10,6 +10,7 @@ from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from typing import Union, Optional, List, Tuple
import torch
import torch.nn as nn
@ -54,8 +55,19 @@ class CatBnAct(nn.Module):
self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
self.act = activation_fn
@torch.jit._overload_method # noqa: F811
def forward(self, x):
x = torch.cat(x, dim=1) if isinstance(x, tuple) else x
# type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor)
pass
@torch.jit._overload_method # noqa: F811
def forward(self, x):
# type: (torch.Tensor) -> (torch.Tensor)
pass
def forward(self, x):
if isinstance(x, tuple):
x = torch.cat(x, dim=1)
return self.act(self.bn(x))
@ -107,6 +119,8 @@ class DualPathBlock(nn.Module):
self.key_stride = 1
self.has_proj = False
self.c1x1_w_s1 = None
self.c1x1_w_s2 = None
if self.has_proj:
# Using different member names here to allow easier parameter key matching for conversion
if self.key_stride == 2:
@ -115,6 +129,7 @@ class DualPathBlock(nn.Module):
else:
self.c1x1_w_s1 = BnActConv2d(
in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
self.c3x3_b = BnActConv2d(
in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
@ -125,27 +140,46 @@ class DualPathBlock(nn.Module):
self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
else:
self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)
self.c1x1_c1 = None
self.c1x1_c2 = None
@torch.jit._overload_method # noqa: F811
def forward(self, x):
x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
if self.has_proj:
if self.key_stride == 2:
x_s = self.c1x1_w_s2(x_in)
else:
x_s = self.c1x1_w_s1(x_in)
x_s1 = x_s[:, :self.num_1x1_c, :, :]
x_s2 = x_s[:, self.num_1x1_c:, :, :]
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
pass
@torch.jit._overload_method # noqa: F811
def forward(self, x):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
pass
def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(x, tuple):
x_in = torch.cat(x, dim=1)
else:
x_in = x
if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None:
# self.has_proj == False, torchscript requires condition on module == None
x_s1 = x[0]
x_s2 = x[1]
else:
# self.has_proj == True
if self.c1x1_w_s1 is not None:
# self.key_stride = 1
x_s = self.c1x1_w_s1(x_in)
else:
# self.key_stride = 2
x_s = self.c1x1_w_s2(x_in)
x_s1 = x_s[:, :self.num_1x1_c, :, :]
x_s2 = x_s[:, self.num_1x1_c:, :, :]
x_in = self.c1x1_a(x_in)
x_in = self.c3x3_b(x_in)
if self.b:
x_in = self.c1x1_c(x_in)
x_in = self.c1x1_c(x_in)
if self.c1x1_c1 is not None:
# self.b == True, using None check for torchscript compat
out1 = self.c1x1_c1(x_in)
out2 = self.c1x1_c2(x_in)
else:
x_in = self.c1x1_c(x_in)
out1 = x_in[:, :self.num_1x1_c, :, :]
out2 = x_in[:, self.num_1x1_c:, :, :]
resid = x_s1 + out1
@ -167,11 +201,9 @@ class DPN(nn.Module):
# conv1
if small:
blocks['conv1_1'] = InputBlock(
num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
else:
blocks['conv1_1'] = InputBlock(
num_init_features, in_chans=in_chans, kernel_size=7, padding=3)
blocks['conv1_1'] = InputBlock(num_init_features, in_chans=in_chans, kernel_size=7, padding=3)
# conv2
bw = 64 * bw_factor

@ -24,11 +24,15 @@ An implementation of EfficienNet that covers variety of related models with effi
Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_builder import *
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d
from .layers import SelectAdaptivePool2d, create_conv2d
from .registry import register_model
__all__ = ['EfficientNet']
@ -631,7 +635,7 @@ def _gen_mobilenet_v2(
fix_stem=fix_stem_head,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=nn.ReLU6,
act_layer=resolve_act_layer(kwargs, 'relu6'),
**kwargs
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@ -741,7 +745,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=Swish,
act_layer=resolve_act_layer(kwargs, 'swish'),
norm_kwargs=resolve_bn_args(kwargs),
variant=variant,
**kwargs,
@ -772,7 +776,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=nn.ReLU,
act_layer=resolve_act_layer(kwargs, 'relu'),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@ -802,7 +806,7 @@ def _gen_efficientnet_condconv(
stem_size=32,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=Swish,
act_layer=resolve_act_layer(kwargs, 'swish'),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@ -842,7 +846,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
stem_size=32,
fix_stem=True,
channel_multiplier=channel_multiplier,
act_layer=nn.ReLU6,
act_layer=resolve_act_layer(kwargs, 'relu6'),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)

@ -1,9 +1,9 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers.activations import sigmoid
from .layers import create_conv2d, drop_path
from .layers import create_conv2d, drop_path, get_act_layer
from .layers.activations import sigmoid
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
@ -52,6 +52,13 @@ def resolve_se_args(kwargs, in_chs, act_layer=None):
return se_kwargs
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@ -213,7 +220,7 @@ class InvertedResidual(nn.Module):
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
print(act_layer)
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)

@ -1,13 +1,15 @@
import logging
import math
import re
from collections.__init__ import OrderedDict
from collections import OrderedDict
from copy import deepcopy
import torch.nn as nn
from .layers import CondConv2d, get_condconv_initializer
from .layers.activations import HardSwish, Swish
from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
def _parse_ksize(ss):
@ -57,13 +59,13 @@ def _decode_block_str(block_str):
key = op[0]
v = op[1:]
if v == 're':
value = nn.ReLU
value = get_act_layer('relu')
elif v == 'r6':
value = nn.ReLU6
value = get_act_layer('relu6')
elif v == 'hs':
value = HardSwish
value = get_act_layer('hard_swish')
elif v == 'sw':
value = Swish
value = get_act_layer('swish')
else:
continue
options[key] = value

@ -1,25 +1,28 @@
from .padding import get_padding
from .pool2d_same import AvgPool2dSame
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .pool2d_same import create_pool2d
from .create_conv2d import create_conv2d
from .create_attn import create_attn
from .selective_kernel import SelectiveKernelConv
from .se import SEModule
from .eca import EcaModule, CecaModule
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule
from .blur_pool import BlurPool2d
from .norm_act import BatchNormAct2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .create_norm_act import create_norm_act
from .inplace_abn import InplaceAbn
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_

@ -6,85 +6,15 @@ easily be swapped. All have an `inplace` arg even if not used.
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
_USE_MEM_EFFICIENT_ISH = True
if _USE_MEM_EFFICIENT_ISH:
# This version reduces memory overhead of Swish during training by
# recomputing torch.sigmoid(x) in backward instead of saving it.
@torch.jit.script
def swish_jit_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
def swish(x, _inplace=False):
return SwishJitAutoFn.apply(x)
@torch.jit.script
def mish_jit_fwd(x):
return x.mul(torch.tanh(F.softplus(x)))
@torch.jit.script
def mish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
def mish(x, _inplace=False):
return MishJitAutoFn.apply(x)
else:
def swish(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
def mish(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
def swish(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
class Swish(nn.Module):
@ -96,13 +26,21 @@ class Swish(nn.Module):
return swish(x, self.inplace)
def mish(x, inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
NOTE: I don't have a working inplace variant
"""
return x.mul(F.softplus(x).tanh())
class Mish(nn.Module):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
def __init__(self, inplace: bool = False):
super(Mish, self).__init__()
self.inplace = inplace
def forward(self, x):
return mish(x, self.inplace)
return mish(x)
def sigmoid(x, inplace: bool = False):
@ -162,3 +100,22 @@ class HardSigmoid(nn.Module):
def forward(self, x):
return hard_sigmoid(x, self.inplace)
def hard_mish(x, inplace: bool = False):
""" Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
if inplace:
return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
else:
return 0.5 * x * (x + 2).clamp(min=0, max=2)
class HardMish(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMish, self).__init__()
self.inplace = inplace
def forward(self, x):
return hard_mish(x, self.inplace)

@ -0,0 +1,90 @@
""" Activations
A collection of jit-scripted activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
versions if they contain in-place ops.
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
@torch.jit.script
def swish_jit(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul(x.sigmoid())
@torch.jit.script
def mish_jit(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class SwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishJit, self).__init__()
def forward(self, x):
return swish_jit(x)
class MishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(MishJit, self).__init__()
def forward(self, x):
return mish_jit(x)
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
# return F.relu6(x + 3.) / 6.
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSigmoidJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidJit, self).__init__()
def forward(self, x):
return hard_sigmoid_jit(x)
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
# return x * (F.relu6(x + 3.) / 6)
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishJit, self).__init__()
def forward(self, x):
return hard_swish_jit(x)
@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
""" Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
return 0.5 * x * (x + 2).clamp(min=0, max=2)
class HardMishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMishJit, self).__init__()
def forward(self, x):
return hard_mish_jit(x)

@ -0,0 +1,208 @@
""" Activations (memory-efficient w/ custom autograd)
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
@torch.jit.script
def swish_jit_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
def swish_me(x, inplace=False):
return SwishJitAutoFn.apply(x)
class SwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishMe, self).__init__()
def forward(self, x):
return SwishJitAutoFn.apply(x)
@torch.jit.script
def mish_jit_fwd(x):
return x.mul(torch.tanh(F.softplus(x)))
@torch.jit.script
def mish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
A memory efficient, jit scripted variant of Mish
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
def mish_me(x, inplace=False):
return MishJitAutoFn.apply(x)
class MishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(MishMe, self).__init__()
def forward(self, x):
return MishJitAutoFn.apply(x)
@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
return (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
return grad_output * m
class HardSigmoidJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_sigmoid_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_sigmoid_jit_bwd(x, grad_output)
def hard_sigmoid_me(x, inplace: bool = False):
return HardSigmoidJitAutoFn.apply(x)
class HardSigmoidMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidMe, self).__init__()
def forward(self, x):
return HardSigmoidJitAutoFn.apply(x)
@torch.jit.script
def hard_swish_jit_fwd(x):
return x * (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= 3.)
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
return grad_output * m
class HardSwishJitAutoFn(torch.autograd.Function):
"""A memory efficient, jit-scripted HardSwish activation"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_swish_jit_bwd(x, grad_output)
def hard_swish_me(x, inplace=False):
return HardSwishJitAutoFn.apply(x)
class HardSwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishMe, self).__init__()
def forward(self, x):
return HardSwishJitAutoFn.apply(x)
@torch.jit.script
def hard_mish_jit_fwd(x):
return 0.5 * x * (x + 2).clamp(min=0, max=2)
@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= -2.)
m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
return grad_output * m
class HardMishJitAutoFn(torch.autograd.Function):
""" A memory efficient, jit scripted variant of Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
def hard_mish_me(x, inplace: bool = False):
return HardMishJitAutoFn.apply(x)
class HardMishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMishMe, self).__init__()
def forward(self, x):
return HardMishJitAutoFn.apply(x)

@ -15,7 +15,7 @@ from torch.nn import functional as F
from .helpers import tup_pair
from .conv2d_same import conv2d_same
from timm.models.layers.padding import get_padding_value
from .padding import get_padding_value
def get_condconv_initializer(initializer, num_experts, expert_shape):

@ -0,0 +1,74 @@
""" Model / Layer Config Singleton
"""
from typing import Any
__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit']
# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False
# Set to True if prefer to have activation layers with no jit optimization
_NO_ACTIVATION_JIT = False
# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False
# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False
def is_no_jit():
return _NO_JIT
class set_no_jit:
def __init__(self, mode: bool) -> None:
global _NO_JIT
self.prev = _NO_JIT
_NO_JIT = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _NO_JIT
_NO_JIT = self.prev
return False
def is_exportable():
return _EXPORTABLE
class set_exportable:
def __init__(self, mode: bool) -> None:
global _EXPORTABLE
self.prev = _EXPORTABLE
_EXPORTABLE = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _EXPORTABLE
_EXPORTABLE = self.prev
return False
def is_scriptable():
return _SCRIPTABLE
class set_scriptable:
def __init__(self, mode: bool) -> None:
global _SCRIPTABLE
self.prev = _SCRIPTABLE
_SCRIPTABLE = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
_SCRIPTABLE = self.prev
return False

@ -7,8 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from timm.models.layers.padding import get_padding_value
from .padding import pad_same
from .padding import pad_same, get_padding_value
def conv2d_same(

@ -4,33 +4,28 @@ Hacked together by Ross Wightman
"""
from torch import nn as nn
from timm.models.layers import get_padding
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True,
drop_block=None, aa_layer=None):
super(ConvBnAct, self).__init__()
padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
use_aa = aa_layer is not None
self.conv = nn.Conv2d(
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1 if use_aa else stride,
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
self.bn = norm_layer(out_channels)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
self.drop_block = drop_block
if act_layer is not None:
self.act = act_layer(inplace=True)
else:
self.act = None
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.act is not None:
x = self.act(x)
if self.aa is not None:
x = self.aa(x)
return x

@ -0,0 +1,103 @@
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
_ACT_FN_DEFAULT = dict(
swish=swish,
mish=mish,
relu=F.relu,
relu6=F.relu6,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
hard_swish=hard_swish,
hard_mish=hard_mish,
)
_ACT_FN_JIT = dict(
swish=swish_jit,
mish=mish_jit,
hard_sigmoid=hard_sigmoid_jit,
hard_swish=hard_swish_jit,
hard_mish=hard_mish_jit
)
_ACT_FN_ME = dict(
swish=swish_me,
mish=mish_me,
hard_sigmoid=hard_sigmoid_me,
hard_swish=hard_swish_me,
hard_mish=hard_mish_me,
)
_ACT_LAYER_DEFAULT = dict(
swish=Swish,
mish=Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,
hard_swish=HardSwish,
hard_mish=HardMish,
)
_ACT_LAYER_JIT = dict(
swish=SwishJit,
mish=MishJit,
hard_sigmoid=HardSigmoidJit,
hard_swish=HardSwishJit,
hard_mish=HardMishJit
)
_ACT_LAYER_ME = dict(
swish=SwishMe,
mish=MishMe,
hard_sigmoid=HardSigmoidMe,
hard_swish=HardSwishMe,
hard_mish=HardMishMe,
)
def get_act_fn(name='relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
if not (is_no_jit() or is_exportable() or is_scriptable()):
# If not exporting or scripting the model, first look for a memory-efficient version with
# custom autograd, then fallback
if name in _ACT_FN_ME:
return _ACT_FN_ME[name]
if not is_no_jit():
if name in _ACT_FN_JIT:
return _ACT_FN_JIT[name]
return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
if not (is_no_jit() or is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
if not is_no_jit():
if name in _ACT_LAYER_JIT:
return _ACT_LAYER_JIT[name]
return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name, inplace=False, **kwargs):
act_layer = get_act_layer(name)
if act_layer is not None:
return act_layer(inplace=inplace, **kwargs)
else:
return None

@ -3,7 +3,7 @@
Hacked together by Ross Wightman
"""
import torch
from .se import SEModule
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule
@ -15,6 +15,8 @@ def create_attn(attn_type, channels, **kwargs):
attn_type = attn_type.lower()
if attn_type == 'se':
module_cls = SEModule
elif attn_type == 'ese':
module_cls = EffectiveSEModule
elif attn_type == 'eca':
module_cls = EcaModule
elif attn_type == 'ceca':

@ -8,23 +8,23 @@ from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
""" Select a 2d convolution implementation based on arguments
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
Used extensively by EfficientNet, MobileNetv3 and related networks.
"""
assert 'groups' not in kwargs # only use 'depthwise' bool arg
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
assert 'groups' not in kwargs # MixedConv groups are defined by kernel list
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_chs if depthwise else 1
groups = out_channels if depthwise else kwargs.pop('groups', 1)
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
return m

@ -1,37 +1,64 @@
import types
import functools
import torch
import torch.nn as nn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .norm_act import BatchNormAct2d
try:
from inplace_abn import InPlaceABN
has_iabn = True
except ImportError:
has_iabn = False
from .norm_act import BatchNormAct2d, GroupNormAct
from .inplace_abn import InplaceAbn
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
def create_norm_act(layer_type, num_features, jit=False, **kwargs):
layer_parts = layer_type.split('_')
assert len(layer_parts) in (1, 2)
layer_class = layer_parts[0].lower()
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection
if layer_class == "batchnormact":
layer = BatchNormAct2d(num_features, **kwargs) # defaults to RELU of no kwargs override
elif layer_class == "batchnormrelu":
assert 'act_layer' not in kwargs
layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
def get_norm_act_layer(layer_class):
layer_class = layer_class.replace('_', '').lower()
if layer_class.startswith("batchnorm"):
layer = BatchNormAct2d
elif layer_class.startswith("groupnorm"):
layer = GroupNormAct
elif layer_class == "evonormbatch":
layer = EvoNormBatch2d(num_features, **kwargs)
layer = EvoNormBatch2d
elif layer_class == "evonormsample":
layer = EvoNormSample2d(num_features, **kwargs)
layer = EvoNormSample2d
elif layer_class == "iabn" or layer_class == "inplaceabn":
if not has_iabn:
raise ImportError(
"Pplease install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
layer = InPlaceABN(num_features, **kwargs)
layer = InplaceAbn
else:
assert False, "Invalid norm_act layer (%s)" % layer_class
if jit:
layer = torch.jit.script(layer)
return layer
def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu
assert len(layer_parts) in (1, 2)
layer = get_norm_act_layer(layer_parts[0])
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
if jit:
layer_instance = torch.jit.script(layer_instance)
return layer_instance
def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
if isinstance(norm_layer, str):
norm_act_layer = get_norm_act_layer(norm_layer)
elif norm_layer in _NORM_ACT_TYPES:
norm_act_layer = norm_layer
elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
# assuming this is a lambda/fn/bound partial that creates norm_act layer
norm_act_layer = norm_layer
else:
type_name = norm_layer.__name__.lower()
if type_name.startswith('batchnorm'):
norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct
else:
assert False, f"No equivalent norm_act layer for {type_name}"
# Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
# Newer models will use `apply_act` and likely have `act_layer` arg bound to relevant NormAct types.
norm_act_args.update(dict(act_layer=act_layer))
return norm_act_layer, norm_act_args

@ -17,8 +17,6 @@ Hacked together by Ross Wightman
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
def drop_block_2d(

@ -2,9 +2,9 @@
An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other impl, still quite a ways off the built-in BN
in terms of memory usage and throughput.
in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
Still very much a WIP, fiddling with buffer usage, in-place optimizations, and layouts.
Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
Hacked together by Ross Wightman
"""
@ -14,15 +14,15 @@ import torch.nn as nn
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, momentum=0.1, nonlin=True, eps=1e-5):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
super(EvoNormBatch2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.nonlin = nonlin
self.eps = eps
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
if nonlin:
if apply_act:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
self.reset_parameters()
@ -30,7 +30,7 @@ class EvoNormBatch2d(nn.Module):
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.nonlin:
if self.apply_act:
nn.init.ones_(self.v)
def forward(self, x):
@ -40,46 +40,42 @@ class EvoNormBatch2d(nn.Module):
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
else:
var = self.running_var.clone()
var = self.running_var
if self.nonlin:
if self.apply_act:
v = self.v.to(dtype=x_type)
d = (x * v) + x.var(dim=(2, 3), unbiased=False, keepdim=True).add_(self.eps).sqrt_().to(dtype=x_type)
d = d.max(var.add_(self.eps).sqrt_().to(dtype=x_type))
d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
x = x / d
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)
return x * self.weight + self.bias
class EvoNormSample2d(nn.Module):
def __init__(self, num_features, nonlin=True, groups=8, eps=1e-5):
def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
super(EvoNormSample2d, self).__init__()
self.nonlin = nonlin
self.apply_act = apply_act # apply activation (non-linearity)
self.groups = groups
self.eps = eps
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
if nonlin:
if apply_act:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.nonlin:
if self.apply_act:
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
B, C, H, W = x.shape
assert C % self.groups == 0
if self.nonlin:
if self.apply_act:
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
x = x.reshape(B, self.groups, -1)
x = n / x.var(dim=-1, unbiased=False, keepdim=True).add_(self.eps).sqrt_()
x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
x = x.reshape(B, C, H, W)
return x.mul_(self.weight).add_(self.bias)
else:
return x.mul(self.weight).add_(self.bias)
return x * self.weight + self.bias

@ -0,0 +1,85 @@
import torch
from torch import nn as nn
try:
from inplace_abn.functions import inplace_abn, inplace_abn_sync
has_iabn = True
except ImportError:
has_iabn = False
def inplace_abn(x, weight, bias, running_mean, running_var,
training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
raise ImportError(
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
def inplace_abn_sync(**kwargs):
inplace_abn(**kwargs)
class InplaceAbn(nn.Module):
"""Activated Batch Normalization
This gathers a BatchNorm and an activation function in a single module
Parameters
----------
num_features : int
Number of feature channels in the input and output.
eps : float
Small constant to prevent numerical issues.
momentum : float
Momentum factor applied to compute running statistics.
affine : bool
If `True` apply learned scale and shift transformation after normalization.
act_layer : str or nn.Module type
Name or type of the activation functions, one of: `leaky_relu`, `elu`
act_param : float
Negative slope for the `leaky_relu` activation.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
act_layer="leaky_relu", act_param=0.01, drop_block=None,):
super(InplaceAbn, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if apply_act:
if isinstance(act_layer, str):
assert act_layer in ('leaky_relu', 'elu', 'identity')
self.act_name = act_layer
else:
# convert act layer passed as type to string
if isinstance(act_layer, nn.ELU):
self.act_name = 'elu'
elif isinstance(act_layer, nn.LeakyReLU):
self.act_name = 'leaky_relu'
else:
assert False, f'Invalid act layer {act_layer.__name__} for IABN'
else:
self.act_name = 'identity'
self.act_param = act_param
if self.affine:
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.constant_(self.running_mean, 0)
nn.init.constant_(self.running_var, 1)
if self.affine:
nn.init.constant_(self.weight, 1)
nn.init.constant_(self.bias, 0)
def forward(self, x):
output = inplace_abn(
x, self.weight, self.bias, self.running_mean, self.running_var,
self.training, self.momentum, self.eps, self.act_name, self.act_param)
if isinstance(output, tuple):
output = output[0]
return output

@ -1,28 +1,33 @@
""" Normalization + Activation Layers
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
from .create_act import get_act_layer
class BatchNormAct2d(nn.BatchNorm2d):
"""BatchNorm + Activation
This module performs BatchNorm + Actibation in s manner that will remain bavkwards
This module performs BatchNorm + Activation in a manner that will remain backwards
compatible with weights trained with separate bn, act. This is why we inherit from BN
instead of composing it as a .bn member.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True, act_layer=nn.ReLU, inplace=True):
super(BatchNormAct2d, self).__init__(num_features, eps, momentum, affine, track_running_stats)
self.act = act_layer(inplace=inplace)
def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
# x = super(BatchNormAct2d, self).forward(x)
# BEGIN nn.BatchNorm2d forward() cut & paste
# self._check_input_dim(x)
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
if act_layer is not None and apply_act:
self.act = act_layer(inplace=inplace)
else:
self.act = None
def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
"""
# exponential_average_factor is self.momentum set to
# (when it is available) only so that if gets updated
# in ONNX graph when this node is exported to ONNX.
@ -41,10 +46,40 @@ class BatchNormAct2d(nn.BatchNorm2d):
exponential_average_factor = self.momentum
x = F.batch_norm(
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
# END BatchNorm2d forward()
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
return x
@torch.jit.ignore
def _forward_python(self, x):
return super(BatchNormAct2d, self).forward(x)
def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
if torch.jit.is_scripting():
x = self._forward_jit(x)
else:
self._forward_python(x)
if self.act is not None:
x = self.act(x)
return x
x = self.act(x)
class GroupNormAct(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
if act_layer is not None and apply_act:
self.act = act_layer(inplace=inplace)
else:
self.act = None
def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self.act is not None:
x = self.act(x)
return x

@ -6,7 +6,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, List, Tuple, Optional
import math
from .helpers import tup_pair
from .padding import pad_same, get_padding_value

@ -1,9 +1,11 @@
from torch import nn as nn
from .create_act import get_act_fn
class SEModule(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
gate_fn='hard_sigmoid'):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
@ -12,10 +14,27 @@ class SEModule(nn.Module):
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.gate_fn = get_act_fn(gate_fn)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()
return x * self.gate_fn(x_se)
class EffectiveSEModule(nn.Module):
""" 'Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channel, gate_fn='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
self.gate_fn = get_act_fn(gate_fn)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc(x_se)
return x * self.gate_fn(x_se, inplace=True)

@ -4,7 +4,6 @@ Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn

@ -0,0 +1,51 @@
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type
class SeparableConvBnAct(nn.Module):
""" Separable Conv w/ trailing Norm and Activation
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
act_layer=nn.ReLU, apply_act=True, drop_block=None):
super(SeparableConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size,
stride=stride, dilation=dilation, padding=padding, depthwise=True)
self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
def forward(self, x):
x = self.conv_dw(x)
x = self.conv_pw(x)
if self.bn is not None:
x = self.bn(x)
return x
class SeparableConv2d(nn.Module):
""" Separable Conv
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1):
super(SeparableConv2d, self).__init__()
self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size,
stride=stride, dilation=dilation, padding=padding, depthwise=True)
self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
def forward(self, x):
x = self.conv_dw(x)
x = self.conv_pw(x)
return x

@ -6,6 +6,7 @@ Hacked together by Ross Wightman
import logging
from torch import nn
import torch.nn.functional as F
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d

@ -7,13 +7,15 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_builder import *
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, create_conv2d
from .layers.activations import HardSwish, hard_sigmoid
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model
__all__ = ['MobileNetV3']
@ -273,8 +275,8 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
head_bias=False,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=HardSwish,
se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@ -293,7 +295,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
if 'small' in variant:
num_features = 1024
if 'minimal' in variant:
act_layer = nn.ReLU
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16'],
@ -309,7 +311,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
['cn_r1_k1_s1_c576'],
]
else:
act_layer = HardSwish
act_layer = resolve_act_layer(kwargs, 'hard_swish')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
@ -327,7 +329,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
else:
num_features = 1280
if 'minimal' in variant:
act_layer = nn.ReLU
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'],
@ -345,7 +347,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
['cn_r1_k1_s1_c960'],
]
else:
act_layer = HardSwish
act_layer = resolve_act_layer(kwargs, 'hard_swish')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu

@ -43,11 +43,12 @@ class MaxPool(nn.Module):
self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
def forward(self, x):
if self.zero_pad:
if self.zero_pad is not None:
x = self.zero_pad(x)
x = self.pool(x)
if self.zero_pad:
x = self.pool(x)
x = x[:, :, 1:, 1:]
else:
x = self.pool(x)
return x
@ -90,11 +91,12 @@ class BranchSeparables(nn.Module):
def forward(self, x):
x = self.relu_1(x)
if self.zero_pad:
if self.zero_pad is not None:
x = self.zero_pad(x)
x = self.separable_1(x)
if self.zero_pad:
x = self.separable_1(x)
x = x[:, :, 1:, 1:].contiguous()
else:
x = self.separable_1(x)
x = self.bn_sep_1(x)
x = self.relu_2(x)
x = self.separable_2(x)
@ -171,15 +173,14 @@ class CellBase(nn.Module):
x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
x_comb_iter_4_left = self.comb_iter_4_left(x_left)
if self.comb_iter_4_right:
if self.comb_iter_4_right is not None:
x_comb_iter_4_right = self.comb_iter_4_right(x_right)
else:
x_comb_iter_4_right = x_right
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
x_out = torch.cat(
[x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3,
x_comb_iter_4], 1)
[x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
return x_out
@ -280,9 +281,8 @@ class Cell(CellBase):
kernel_size=3, stride=stride,
zero_pad=zero_pad)
if is_reduction:
self.comb_iter_4_right = ReluConvBn(out_channels_right,
out_channels_right,
kernel_size=1, stride=stride)
self.comb_iter_4_right = ReluConvBn(
out_channels_right, out_channels_right, kernel_size=1, stride=stride)
else:
self.comb_iter_4_right = None

@ -77,6 +77,8 @@ class Bottle2neck(nn.Module):
if self.is_first:
# FIXME this should probably have count_include_pad=False, but hurts original weights
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
else:
self.pool = None
self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
@ -97,14 +99,22 @@ class Bottle2neck(nn.Module):
spx = torch.split(out, self.width, 1)
spo = []
sp = spx[0]
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
sp = spx[i] if i == 0 or self.is_first else sp + spx[i]
if self.is_first:
sp = spx[i]
else:
sp = sp + spx[i]
sp = conv(sp)
sp = bn(sp)
sp = self.relu(sp)
spo.append(sp)
if self.scale > 1:
spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
if self.pool is not None:
# self.is_first == True, None check for torchscript
spo.append(self.pool(spx[-1]))
else:
spo.append(spx[-1])
out = torch.cat(spo, 1)
out = self.conv3(out)

@ -200,7 +200,6 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,

@ -9,6 +9,7 @@ https://arxiv.org/abs/1907.00837
Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models
and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch
"""
from typing import List
import torch
import torch.nn as nn
@ -52,6 +53,27 @@ default_cfgs = {
}
class SequentialList(nn.Sequential):
def __init__(self, *args):
super(SequentialList, self).__init__(*args)
@torch.jit._overload_method # noqa: F811
def forward(self, x):
# type: (List[torch.Tensor]) -> (List[torch.Tensor])
pass
@torch.jit._overload_method # noqa: F811
def forward(self, x):
# type: (torch.Tensor) -> (List[torch.Tensor])
pass
def forward(self, x) -> List[torch.Tensor]:
for module in self:
x = module(x)
return x
def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
if padding is None:
padding = ((stride - 1) + dilation * (k - 1)) // 2
@ -77,7 +99,7 @@ class SelecSLSBlock(nn.Module):
self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3)
self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)
def forward(self, x):
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
assert isinstance(x, list)
assert len(x) in [1, 2]
@ -113,7 +135,7 @@ class SelecSLS(nn.Module):
super(SelecSLS, self).__init__()
self.stem = conv_bn(in_chans, 32, stride=2)
self.features = nn.Sequential(*[cfg['block'](*block_args) for block_args in cfg['features']])
self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']])
self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
self.num_features = cfg['num_features']

@ -13,15 +13,9 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import load_pretrained
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d, InplaceAbn
from .registry import register_model
try:
from inplace_abn import InPlaceABN
has_iabn = True
except ImportError:
has_iabn = False
__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@ -91,37 +85,37 @@ class FastSEModule(nn.Module):
def IABN2Float(module: nn.Module) -> nn.Module:
"""If `module` is IABN don't use half precision."""
if isinstance(module, InPlaceABN):
if isinstance(module, InplaceAbn):
module.float()
for child in module.children():
IABN2Float(child)
return module
def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2):
return nn.Sequential(
nn.Conv2d(
ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False),
InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
InplaceAbn(nf, act_layer=act_layer, act_param=act_param)
)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None):
super(BasicBlock, self).__init__()
if stride == 1:
self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3)
self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3)
else:
if anti_alias_layer is None:
self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
if aa_layer is None:
self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3)
else:
self.conv1 = nn.Sequential(
conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
anti_alias_layer(channels=planes, filt_size=3, stride=2))
conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3),
aa_layer(channels=planes, filt_size=3, stride=2))
self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity")
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
@ -148,24 +142,25 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True,
act_layer="leaky_relu", aa_layer=None):
super(Bottleneck, self).__init__()
self.conv1 = conv2d_ABN(
inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu", activation_param=1e-3)
self.conv1 = conv2d_iabn(
inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3)
if stride == 1:
self.conv2 = conv2d_ABN(
planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3)
self.conv2 = conv2d_iabn(
planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3)
else:
if anti_alias_layer is None:
self.conv2 = conv2d_ABN(
planes, planes, kernel_size=3, stride=2, activation="leaky_relu", activation_param=1e-3)
if aa_layer is None:
self.conv2 = conv2d_iabn(
planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3)
else:
self.conv2 = nn.Sequential(
conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3),
anti_alias_layer(channels=planes, filt_size=3, stride=2))
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
aa_layer(channels=planes, filt_size=3, stride=2))
self.conv3 = conv2d_ABN(
planes, planes * self.expansion, kernel_size=1, stride=1, activation="identity")
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
@ -195,30 +190,26 @@ class Bottleneck(nn.Module):
class TResNet(nn.Module):
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
global_pool='avg', drop_rate=0.):
if not has_iabn:
raise ImportError(
"For TResNet models, please install InplaceABN: "
"'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
self.num_classes = num_classes
self.drop_rate = drop_rate
super(TResNet, self).__init__()
# JIT layers
space_to_depth = SpaceToDepthModule()
anti_alias_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
# TResnet stages
self.inplanes = int(64 * width_factor)
self.planes = int(64 * width_factor)
conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3)
layer1 = self._make_layer(
BasicBlock, self.planes, layers[0], stride=1, use_se=True, anti_alias_layer=anti_alias_layer) # 56x56
BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer) # 56x56
layer2 = self._make_layer(
BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, anti_alias_layer=anti_alias_layer) # 28x28
BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer) # 28x28
layer3 = self._make_layer(
Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, anti_alias_layer=anti_alias_layer) # 14x14
Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer) # 14x14
layer4 = self._make_layer(
Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, anti_alias_layer=anti_alias_layer) # 7x7
Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer) # 7x7
# body
self.body = nn.Sequential(OrderedDict([
@ -239,7 +230,7 @@ class TResNet(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN):
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@ -251,24 +242,24 @@ class TResNet(nn.Module):
m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero
if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)
def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None):
def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
layers = []
if stride == 2:
# avg pooling before 1x1 conv
layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
layers += [conv2d_ABN(
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, activation="identity")]
layers += [conv2d_iabn(
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")]
downsample = nn.Sequential(*layers)
layers = []
layers.append(block(
self.inplanes, planes, stride, downsample, use_se=use_se, anti_alias_layer=anti_alias_layer))
self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer))
return nn.Sequential(*layers)
def get_classifier(self):

@ -0,0 +1,408 @@
""" VoVNet (V1 & V2)
Papers:
* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
for some reference, rewrote most of the code.
Hacked together by Ross Wightman
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .helpers import load_pretrained
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
create_attn, create_norm_act, get_norm_act_layer
# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
model_cfgs = dict(
vovnet39a=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=5,
block_per_stage=[1, 1, 2, 2],
residual=False,
depthwise=False,
attn='',
),
vovnet57a=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=5,
block_per_stage=[1, 1, 4, 3],
residual=False,
depthwise=False,
attn='',
),
ese_vovnet19b_slim_dw=dict(
stem_ch=[64, 64, 64],
stage_conv_ch=[64, 80, 96, 112],
stage_out_ch=[112, 256, 384, 512],
layer_per_block=3,
block_per_stage=[1, 1, 1, 1],
residual=True,
depthwise=True,
attn='ese',
),
ese_vovnet19b_dw=dict(
stem_ch=[64, 64, 64],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=3,
block_per_stage=[1, 1, 1, 1],
residual=True,
depthwise=True,
attn='ese',
),
ese_vovnet19b_slim=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[64, 80, 96, 112],
stage_out_ch=[112, 256, 384, 512],
layer_per_block=3,
block_per_stage=[1, 1, 1, 1],
residual=True,
depthwise=False,
attn='ese',
),
ese_vovnet19b=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=3,
block_per_stage=[1, 1, 1, 1],
residual=True,
depthwise=False,
attn='ese',
),
ese_vovnet39b=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=5,
block_per_stage=[1, 1, 2, 2],
residual=True,
depthwise=False,
attn='ese',
),
ese_vovnet57b=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=5,
block_per_stage=[1, 1, 4, 3],
residual=True,
depthwise=False,
attn='ese',
),
ese_vovnet99b=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=5,
block_per_stage=[1, 3, 9, 3],
residual=True,
depthwise=False,
attn='ese',
),
eca_vovnet39b=dict(
stem_ch=[64, 64, 128],
stage_conv_ch=[128, 160, 192, 224],
stage_out_ch=[256, 512, 768, 1024],
layer_per_block=5,
block_per_stage=[1, 1, 2, 2],
residual=True,
depthwise=False,
attn='eca',
),
)
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
}
default_cfgs = dict(
vovnet39a=_cfg(url=''),
vovnet57a=_cfg(url=''),
ese_vovnet19b_slim_dw=_cfg(url=''),
ese_vovnet19b_dw=_cfg(url=''),
ese_vovnet19b_slim=_cfg(url=''),
ese_vovnet39b=_cfg(url=''),
ese_vovnet57b=_cfg(url=''),
ese_vovnet99b=_cfg(url=''),
eca_vovnet39b=_cfg(url=''),
)
class SequentialAppendList(nn.Sequential):
def __init__(self, *args):
super(SequentialAppendList, self).__init__(*args)
def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
for i, module in enumerate(self):
if i == 0:
concat_list.append(module(x))
else:
concat_list.append(module(concat_list[-1]))
x = torch.cat(concat_list, dim=1)
return x
class OsaBlock(nn.Module):
def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
depthwise=False, attn='', norm_layer=BatchNormAct2d):
super(OsaBlock, self).__init__()
self.residual = residual
self.depthwise = depthwise
next_in_chs = in_chs
if self.depthwise and next_in_chs != mid_chs:
assert not residual
self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, norm_layer=norm_layer)
else:
self.conv_reduction = None
mid_convs = []
for i in range(layer_per_block):
if self.depthwise:
conv = SeparableConvBnAct(mid_chs, mid_chs, norm_layer=norm_layer)
else:
conv = ConvBnAct(next_in_chs, mid_chs, 3, norm_layer=norm_layer)
next_in_chs = mid_chs
mid_convs.append(conv)
self.conv_mid = SequentialAppendList(*mid_convs)
# feature aggregation
next_in_chs = in_chs + layer_per_block * mid_chs
self.conv_concat = ConvBnAct(next_in_chs, out_chs, norm_layer=norm_layer)
if attn:
self.attn = create_attn(attn, out_chs)
else:
self.attn = None
def forward(self, x):
output = [x]
if self.conv_reduction is not None:
x = self.conv_reduction(x)
x = self.conv_mid(x, output)
x = self.conv_concat(x)
if self.attn is not None:
x = self.attn(x)
if self.residual:
x = x + output[0]
return x
class OsaStage(nn.Module):
def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block,
downsample=True, residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d):
super(OsaStage, self).__init__()
if downsample:
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
else:
self.pool = None
blocks = []
for i in range(block_per_stage):
last_block = i == block_per_stage - 1
blocks += [OsaBlock(
in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0,
depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer)
]
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
if self.pool is not None:
x = self.pool(x)
x = self.blocks(x)
return x
class ClassifierHead(nn.Module):
"""Head."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
if num_classes > 0:
self.fc = nn.Linear(in_chs, num_classes, bias=True)
else:
self.fc = nn.Identity()
def forward(self, x):
x = self.global_pool(x).flatten(1)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
class VovNet(nn.Module):
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
norm_layer=BatchNormAct2d):
""" VovNet (v2)
"""
super(VovNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert stem_stride in (4, 2)
stem_ch = cfg["stem_ch"]
stage_conv_ch = cfg["stage_conv_ch"]
stage_out_ch = cfg["stage_out_ch"]
block_per_stage = cfg["block_per_stage"]
layer_per_block = cfg["layer_per_block"]
# Stem module
last_stem_stride = stem_stride // 2
conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct
self.stem = nn.Sequential(*[
ConvBnAct(in_chans, stem_ch[0], 3, stride=2, norm_layer=norm_layer),
conv_type(stem_ch[0], stem_ch[1], 3, stride=1, norm_layer=norm_layer),
conv_type(stem_ch[1], stem_ch[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
])
# OSA stages
in_ch_list = stem_ch[-1:] + stage_out_ch[:-1]
stage_args = dict(
residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], norm_layer=norm_layer)
stages = []
for i in range(4): # num_stages
downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
stages += [OsaStage(
in_ch_list[i], stage_conv_ch[i], stage_out_ch[i], block_per_stage[i], layer_per_block,
downsample=downsample, **stage_args)
]
self.num_features = stage_out_ch[i]
self.stages = nn.Sequential(*stages)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
x = self.stem(x)
return self.stages(x)
def forward(self, x):
x = self.forward_features(x)
return self.head(x)
def _vovnet(variant, pretrained=False, **kwargs):
load_strict = True
model_class = VovNet
if kwargs.pop('features_only', False):
assert False, 'Not Implemented' # TODO
load_strict = False
kwargs.pop('num_classes', 0)
model_cfg = model_cfgs[variant]
default_cfg = default_cfgs[variant]
model = model_class(model_cfg, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(
model, default_cfg,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
return model
@register_model
def vovnet39a(pretrained=False, **kwargs):
return _vovnet('vovnet39a', pretrained=pretrained, **kwargs)
@register_model
def vovnet57a(pretrained=False, **kwargs):
return _vovnet('vovnet57a', pretrained=pretrained, **kwargs)
@register_model
def ese_vovnet19b_slim_dw(pretrained=False, **kwargs):
return _vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)
@register_model
def ese_vovnet19b_dw(pretrained=False, **kwargs):
return _vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)
@register_model
def ese_vovnet19b_slim(pretrained=False, **kwargs):
return _vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)
@register_model
def ese_vovnet39b(pretrained=False, **kwargs):
return _vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)
@register_model
def ese_vovnet57b(pretrained=False, **kwargs):
return _vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)
@register_model
def ese_vovnet99b(pretrained=False, **kwargs):
return _vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)
@register_model
def eca_vovnet39b(pretrained=False, **kwargs):
return _vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)
# Experimental Models
@register_model
def ese_vovnet39b_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn')
return _vovnet('ese_vovnet39b', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs):
def norm_act_fn(num_features, **kwargs):
return create_norm_act('EvoNormSample', num_features, jit=False, **kwargs)
return _vovnet('ese_vovnet39b', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)

@ -24,7 +24,8 @@ try:
except ImportError:
has_apex = False
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models,\
set_scriptable, set_no_jit
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
@ -84,6 +85,9 @@ def validate(args):
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
if args.torchscript:
set_scriptable(True)
# create model
model = create_model(
args.model,
@ -141,8 +145,10 @@ def validate(args):
top5 = AverageMeter()
model.eval()
end = time.time()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
model(torch.randn((args.batch_size,) + data_config['input_size']).cuda())
end = time.time()
for i, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()

Loading…
Cancel
Save