Fix #661, move hardswish out of default args for LeViT. Enable native torch support for hardswish, hardsigmoid, mish if present.

Branch: pull/669/head
Author: Ross Wightman
Parent: 07d952c7a7
Commit: 9c78de8c02

@@ -8,10 +8,10 @@ from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
class MLP(nn.Module):
def __init__(self, act_layer="relu"):
def __init__(self, act_layer="relu", inplace=True):
super(MLP, self).__init__()
self.fc1 = nn.Linear(1000, 100)
self.act = create_act_layer(act_layer, inplace=True)
self.act = create_act_layer(act_layer, inplace=inplace)
self.fc2 = nn.Linear(100, 10)
def forward(self, x):
@@ -21,14 +21,14 @@ class MLP(nn.Module):
return x
def _run_act_layer_grad(act_type):
def _run_act_layer_grad(act_type, inplace=True):
x = torch.rand(10, 1000) * 10
m = MLP(act_layer=act_type)
m = MLP(act_layer=act_type, inplace=inplace)
def _run(x, act_layer=''):
if act_layer:
# replace act layer if set
m.act = create_act_layer(act_layer, inplace=True)
m.act = create_act_layer(act_layer, inplace=inplace)
out = m(x)
l = (out - 0).pow(2).sum()
return l
@@ -58,7 +58,7 @@ def test_mish_grad():
def test_hard_sigmoid_grad():
for _ in range(100):
_run_act_layer_grad('hard_sigmoid')
_run_act_layer_grad('hard_sigmoid', inplace=None)
def test_hard_swish_grad():
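
Sketch (not part of the diff) of how the reworked test helper is meant to be driven; inplace=None is the case the hard_sigmoid test above exercises, because the native nn.Hardsigmoid selected by the layer factory may not accept an inplace argument on some torch versions:

    _run_act_layer_grad('hard_swish', inplace=True)    # explicit value forwarded as before
    _run_act_layer_grad('hard_sigmoid', inplace=None)  # kwarg not forwarded at all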

@@ -110,7 +110,7 @@ def test_model_backward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):

@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, drop_path, make_divisible
from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
from .layers.activations import sigmoid
__all__ = [
@@ -36,9 +36,9 @@ class SqueezeExcite(nn.Module):
reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
act_layer = force_act_layer or act_layer
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.act1 = create_act_layer(act_layer, inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
self.gate_fn = gate_fn
self.gate_fn = get_act_fn(gate_fn)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
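
With act_layer and gate_fn now resolved through the layer factory, SqueezeExcite accepts plain string names as well as callables (the ghostnet change below relies on this). A minimal sketch of the resolution step, assuming timm is importable:

    from timm.models.layers import get_act_fn

    gate = get_act_fn('hard_sigmoid')   # F.hardsigmoid on recent torch, timm fallback otherwise
    assert get_act_fn(gate) is gate     # already-resolved callables pass straight through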

@@ -50,10 +50,7 @@ def resolve_bn_args(kwargs):
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
return get_act_layer(kwargs.pop('act_layer', default))
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
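
The collapsed resolve_act_layer now leans entirely on get_act_layer, which handles both string names and already-resolved types. A hypothetical call, mirroring how the builder consumes kwargs (values illustrative):

    kwargs = dict(act_layer='hard_swish')   # as passed down from a model entrypoint
    act_layer = resolve_act_layer(kwargs)   # pops the key and resolves the name
    assert 'act_layer' not in kwargs        # the kwarg is consumed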

@@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
from .layers import SelectAdaptivePool2d, Linear, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg
from .registry import register_model
@@ -40,7 +40,7 @@ default_cfgs = {
}
_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4)
_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4)
class GhostModule(nn.Module):

@@ -1,20 +1,26 @@
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Union, Callable, Type
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
# will use native version if present. Eventually, the custom Swish layers will be removed
# and only native 'silu' will be used.
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
# Also hardsigmoid, hardswish, and soon mish. This code will use the native version if present.
# Eventually, the custom SiLU, Mish, and Hard* layers will be removed and only the native variants will be used.
_has_silu = 'silu' in dir(torch.nn.functional)
_has_hardswish = 'hardswish' in dir(torch.nn.functional)
_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
_has_mish = 'mish' in dir(torch.nn.functional)
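
The dir() probes above are plain feature detection, so the module still imports on torch builds that predate the native ops. The same pattern in isolation (standalone sketch, not timm code; the names here are made up):

    import torch
    from torch.nn import functional as F

    _has_hardswish = 'hardswish' in dir(F)   # native op available since torch 1.6

    def _fallback_hard_swish(x):
        # reference formula: x * relu6(x + 3) / 6
        return x * F.relu6(x + 3.) / 6.

    hard_swish = F.hardswish if _has_hardswish else _fallback_hard_swish
    hard_swish(torch.randn(2, 3))
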
_ACT_FN_DEFAULT = dict(
silu=F.silu if _has_silu else swish,
swish=F.silu if _has_silu else swish,
mish=mish,
mish=F.mish if _has_mish else mish,
relu=F.relu,
relu6=F.relu6,
leaky_relu=F.leaky_relu,
@@ -24,33 +30,39 @@ _ACT_FN_DEFAULT = dict(
gelu=gelu,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
hard_swish=hard_swish,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
hard_swish=F.hardswish if _has_hardswish else hard_swish,
hard_mish=hard_mish,
)
_ACT_FN_JIT = dict(
silu=F.silu if _has_silu else swish_jit,
swish=F.silu if _has_silu else swish_jit,
mish=mish_jit,
hard_sigmoid=hard_sigmoid_jit,
hard_swish=hard_swish_jit,
mish=F.mish if _has_mish else mish_jit,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
hard_mish=hard_mish_jit
)
_ACT_FN_ME = dict(
silu=F.silu if _has_silu else swish_me,
swish=F.silu if _has_silu else swish_me,
mish=mish_me,
hard_sigmoid=hard_sigmoid_me,
hard_swish=hard_swish_me,
mish=F.mish if _has_mish else mish_me,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
hard_mish=hard_mish_me,
)
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
for a in _ACT_FNS:
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
a.setdefault('hardswish', a.get('hard_swish'))
_ACT_LAYER_DEFAULT = dict(
silu=nn.SiLU if _has_silu else Swish,
swish=nn.SiLU if _has_silu else Swish,
mish=Mish,
mish=nn.Mish if _has_mish else Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
leaky_relu=nn.LeakyReLU,
@@ -61,37 +73,44 @@ _ACT_LAYER_DEFAULT = dict(
gelu=GELU,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,
hard_swish=HardSwish,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
hard_mish=HardMish,
)
_ACT_LAYER_JIT = dict(
silu=nn.SiLU if _has_silu else SwishJit,
swish=nn.SiLU if _has_silu else SwishJit,
mish=MishJit,
hard_sigmoid=HardSigmoidJit,
hard_swish=HardSwishJit,
mish=nn.Mish if _has_mish else MishJit,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
hard_mish=HardMishJit
)
_ACT_LAYER_ME = dict(
silu=nn.SiLU if _has_silu else SwishMe,
swish=nn.SiLU if _has_silu else SwishMe,
mish=MishMe,
hard_sigmoid=HardSigmoidMe,
hard_swish=HardSwishMe,
mish=nn.Mish if _has_mish else MishMe,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
hard_mish=HardMishMe,
)
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
for a in _ACT_LAYERS:
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
a.setdefault('hardswish', a.get('hard_swish'))
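
The setdefault loops register 'hardsigmoid'/'hardswish' as aliases of the underscored names, so either spelling resolves to the same entry. For example (sketch, assuming timm is importable):

    from timm.models.layers import get_act_layer

    assert get_act_layer('hardswish') is get_act_layer('hard_swish')
    assert get_act_layer('hardsigmoid') is get_act_layer('hard_sigmoid')
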
def get_act_fn(name='relu'):
def get_act_fn(name: Union[Callable, str] = 'relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
if isinstance(name, Callable):
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
# If not exporting or scripting the model, first look for a memory-efficient version with
# custom autograd, then fallback
@@ -106,13 +125,15 @@ def get_act_fn(name='relu'):
return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
if isinstance(name, type):
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
@@ -125,9 +146,8 @@ def get_act_layer(name='relu'):
return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name, inplace=False, **kwargs):
def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
act_layer = get_act_layer(name)
if act_layer is not None:
return act_layer(inplace=inplace, **kwargs)
else:
if act_layer is None:
return None
return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
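
Taken together, the factory now accepts names, types, or callables, and only forwards inplace when it was explicitly set, since some of the native modules it can return (e.g. nn.Hardsigmoid on certain torch versions) do not take that argument. A short sketch of the resulting behaviour, assuming timm is importable:

    import torch.nn as nn
    from timm.models.layers import create_act_layer, get_act_layer

    assert get_act_layer(nn.ReLU) is nn.ReLU        # types pass straight through
    act = create_act_layer('hard_sigmoid')          # inplace=None -> kwarg not forwarded
    relu = create_act_layer('relu', inplace=True)   # explicit value forwarded as before
    assert isinstance(relu, nn.ReLU) and relu.inplace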

@@ -42,7 +42,7 @@ class EffectiveSEModule(nn.Module):
def __init__(self, channels, gate_layer='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate = create_act_layer(gate_layer, inplace=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)

@@ -33,7 +33,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_ntuple
from .layers import to_ntuple, get_act_layer
from .vision_transformer import trunc_normal_
from .registry import register_model
@@ -443,12 +443,14 @@ class Levit(nn.Module):
mlp_ratio=2,
hybrid_backbone=None,
down_ops=None,
act_layer=nn.Hardswish,
attn_act_layer=nn.Hardswish,
act_layer='hard_swish',
attn_act_layer='hard_swish',
distillation=True,
use_conv=False,
drop_path=0):
super().__init__()
act_layer = get_act_layer(act_layer)
attn_act_layer = get_act_layer(attn_act_layer)
if isinstance(img_size, tuple):
# FIXME origin impl passes single img/res dim through whole hierarchy,
# not sure this model will be used enough to spend time fixing it.
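
Using string defaults and resolving them inside __init__ means the module-level defaults no longer reference nn.Hardswish directly; the registry supplies the native layer when torch provides it and timm's HardSwish otherwise. Sketch of that resolution (assuming timm is importable):

    from timm.models.layers import get_act_layer

    act_layer = get_act_layer('hard_swish')   # nn.Hardswish when available, timm HardSwish otherwise
    act = act_layer()                         # instantiated later by the model blocks as usual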
