Fix #661, move hardswish out of default args for LeViT. Enable native torch support for hardswish, hardsigmoid, mish if present.

pull/669/head
Ross Wightman 4 years ago
parent 07d952c7a7
commit 9c78de8c02

@@ -8,10 +8,10 @@ from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
 class MLP(nn.Module):
-    def __init__(self, act_layer="relu"):
+    def __init__(self, act_layer="relu", inplace=True):
         super(MLP, self).__init__()
         self.fc1 = nn.Linear(1000, 100)
-        self.act = create_act_layer(act_layer, inplace=True)
+        self.act = create_act_layer(act_layer, inplace=inplace)
         self.fc2 = nn.Linear(100, 10)

     def forward(self, x):
@@ -21,14 +21,14 @@ class MLP(nn.Module):
         return x


-def _run_act_layer_grad(act_type):
+def _run_act_layer_grad(act_type, inplace=True):
     x = torch.rand(10, 1000) * 10
-    m = MLP(act_layer=act_type)
+    m = MLP(act_layer=act_type, inplace=inplace)

     def _run(x, act_layer=''):
         if act_layer:
             # replace act layer if set
-            m.act = create_act_layer(act_layer, inplace=True)
+            m.act = create_act_layer(act_layer, inplace=inplace)
         out = m(x)
         l = (out - 0).pow(2).sum()
         return l
@@ -58,7 +58,7 @@ def test_mish_grad():

 def test_hard_sigmoid_grad():
     for _ in range(100):
-        _run_act_layer_grad('hard_sigmoid')
+        _run_act_layer_grad('hard_sigmoid', inplace=None)


 def test_hard_swish_grad():

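The `inplace=None` case above is the interesting one: with the factory change later in this diff, `create_act_layer` only forwards `inplace` when it is not None, so activations whose constructors lack an `inplace` argument can still be built from the test. A minimal sketch of that path, assuming `timm.models.layers` at roughly this revision is importable:

import torch
from timm.models.layers import create_act_layer

act = create_act_layer('hard_sigmoid', inplace=None)  # constructor is called with no inplace kwarg
x = torch.randn(8, requires_grad=True)
act(x).pow(2).sum().backward()                        # gradients flow, as the grad tests check
print(type(act), x.grad.shape)
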
@@ -110,7 +110,7 @@ def test_model_backward(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'


-@pytest.mark.timeout(120)
+@pytest.mark.timeout(300)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_default_cfgs(model_name, batch_size):

@@ -7,7 +7,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F

-from .layers import create_conv2d, drop_path, make_divisible
+from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
 from .layers.activations import sigmoid

 __all__ = [
@@ -36,9 +36,9 @@ class SqueezeExcite(nn.Module):
         reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
         act_layer = force_act_layer or act_layer
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
-        self.act1 = act_layer(inplace=True)
+        self.act1 = create_act_layer(act_layer, inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-        self.gate_fn = gate_fn
+        self.gate_fn = get_act_fn(gate_fn)

     def forward(self, x):
         x_se = x.mean((2, 3), keepdim=True)

@@ -50,10 +50,7 @@ def resolve_bn_args(kwargs):


 def resolve_act_layer(kwargs, default='relu'):
-    act_layer = kwargs.pop('act_layer', default)
-    if isinstance(act_layer, str):
-        act_layer = get_act_layer(act_layer)
-    return act_layer
+    return get_act_layer(kwargs.pop('act_layer', default))


 def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):

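Since `get_act_layer` now accepts either a string or an already-resolved class (the type pass-through added later in this diff), the `isinstance` check that used to live in `resolve_act_layer` is redundant. A hedged usage sketch of the simplified helper, assuming imports from `timm.models.layers`:

from timm.models.layers import get_act_layer

def resolve_act_layer(kwargs, default='relu'):
    # same one-liner as the hunk above: pop a name (or class) and let the factory resolve it
    return get_act_layer(kwargs.pop('act_layer', default))

cfg = dict(act_layer='hard_swish', drop_rate=0.2)
print(resolve_act_layer(cfg))  # nn.Hardswish on a PyTorch with the native op, timm's HardSwish* otherwise
print(cfg)                     # 'act_layer' has been popped, 'drop_rate' remains
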
@@ -13,7 +13,7 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
+from .layers import SelectAdaptivePool2d, Linear, make_divisible
 from .efficientnet_blocks import SqueezeExcite, ConvBnAct
 from .helpers import build_model_with_cfg
 from .registry import register_model
@@ -40,7 +40,7 @@ default_cfgs = {
 }

-_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4)
+_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4)


 class GhostModule(nn.Module):

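Passing `gate_fn` as the string 'hard_sigmoid' works because `SqueezeExcite` now resolves it through `get_act_fn` (see the efficientnet_blocks hunk above), which prefers the native torch op when present. A small sketch, assuming a reasonably recent PyTorch:

import torch
from timm.models.layers import get_act_fn

gate_fn = get_act_fn('hard_sigmoid')
print(gate_fn)                            # F.hardsigmoid if the native op exists, timm's own version otherwise
print(gate_fn(torch.linspace(-4, 4, 5)))  # tensor([0.0000, 0.1667, 0.5000, 0.8333, 1.0000])
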
@@ -1,20 +1,26 @@
 """ Activation Factory
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Union, Callable, Type
+
 from .activations import *
 from .activations_jit import *
 from .activations_me import *
 from .config import is_exportable, is_scriptable, is_no_jit

-# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
-# will use native version if present. Eventually, the custom Swish layers will be removed
-# and only native 'silu' will be used.
+# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
+# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
+# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
 _has_silu = 'silu' in dir(torch.nn.functional)
+_has_hardswish = 'hardswish' in dir(torch.nn.functional)
+_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
+_has_mish = 'mish' in dir(torch.nn.functional)


 _ACT_FN_DEFAULT = dict(
     silu=F.silu if _has_silu else swish,
     swish=F.silu if _has_silu else swish,
-    mish=mish,
+    mish=F.mish if _has_mish else mish,
     relu=F.relu,
     relu6=F.relu6,
     leaky_relu=F.leaky_relu,
@@ -24,33 +30,39 @@ _ACT_FN_DEFAULT = dict(
     gelu=gelu,
     sigmoid=sigmoid,
     tanh=tanh,
-    hard_sigmoid=hard_sigmoid,
-    hard_swish=hard_swish,
+    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
+    hard_swish=F.hardswish if _has_hardswish else hard_swish,
     hard_mish=hard_mish,
 )

 _ACT_FN_JIT = dict(
     silu=F.silu if _has_silu else swish_jit,
     swish=F.silu if _has_silu else swish_jit,
-    mish=mish_jit,
-    hard_sigmoid=hard_sigmoid_jit,
-    hard_swish=hard_swish_jit,
+    mish=F.mish if _has_mish else mish_jit,
+    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
+    hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
     hard_mish=hard_mish_jit
 )

 _ACT_FN_ME = dict(
     silu=F.silu if _has_silu else swish_me,
     swish=F.silu if _has_silu else swish_me,
-    mish=mish_me,
-    hard_sigmoid=hard_sigmoid_me,
-    hard_swish=hard_swish_me,
+    mish=F.mish if _has_mish else mish_me,
+    hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
+    hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
     hard_mish=hard_mish_me,
 )

+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+for a in _ACT_FNS:
+    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
+    a.setdefault('hardswish', a.get('hard_swish'))
+
 _ACT_LAYER_DEFAULT = dict(
     silu=nn.SiLU if _has_silu else Swish,
     swish=nn.SiLU if _has_silu else Swish,
-    mish=Mish,
+    mish=nn.Mish if _has_mish else Mish,
     relu=nn.ReLU,
     relu6=nn.ReLU6,
     leaky_relu=nn.LeakyReLU,
@@ -61,37 +73,44 @@ _ACT_LAYER_DEFAULT = dict(
     gelu=GELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
-    hard_sigmoid=HardSigmoid,
-    hard_swish=HardSwish,
+    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
+    hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
     hard_mish=HardMish,
 )

 _ACT_LAYER_JIT = dict(
     silu=nn.SiLU if _has_silu else SwishJit,
     swish=nn.SiLU if _has_silu else SwishJit,
-    mish=MishJit,
-    hard_sigmoid=HardSigmoidJit,
-    hard_swish=HardSwishJit,
+    mish=nn.Mish if _has_mish else MishJit,
+    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
+    hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
     hard_mish=HardMishJit
 )

 _ACT_LAYER_ME = dict(
     silu=nn.SiLU if _has_silu else SwishMe,
     swish=nn.SiLU if _has_silu else SwishMe,
-    mish=MishMe,
-    hard_sigmoid=HardSigmoidMe,
-    hard_swish=HardSwishMe,
+    mish=nn.Mish if _has_mish else MishMe,
+    hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
+    hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
     hard_mish=HardMishMe,
 )

+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+for a in _ACT_LAYERS:
+    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
+    a.setdefault('hardswish', a.get('hard_swish'))
+

-def get_act_fn(name='relu'):
+def get_act_fn(name: Union[Callable, str] = 'relu'):
     """ Activation Function Factory
     Fetching activation fns by name with this function allows export or torch script friendly
     functions to be returned dynamically based on current config.
     """
     if not name:
         return None
+    if isinstance(name, Callable):
+        return name
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         # If not exporting or scripting the model, first look for a memory-efficient version with
         # custom autograd, then fallback
@@ -106,13 +125,15 @@ def get_act_fn(name='relu'):
     return _ACT_FN_DEFAULT[name]


-def get_act_layer(name='relu'):
+def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
     """ Activation Layer Factory
     Fetching activation layers by name with this function allows export or torch script friendly
     functions to be returned dynamically based on current config.
     """
     if not name:
         return None
+    if isinstance(name, type):
+        return name
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
@@ -125,9 +146,8 @@ def get_act_layer(name='relu'):
     return _ACT_LAYER_DEFAULT[name]


-def create_act_layer(name, inplace=False, **kwargs):
+def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
     act_layer = get_act_layer(name)
-    if act_layer is not None:
-        return act_layer(inplace=inplace, **kwargs)
-    else:
-        return None
+    if act_layer is None:
+        return None
+    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
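A quick sketch of the pass-through and inplace handling added in the factories above, assuming the exports in `timm.models.layers` at roughly this revision:

import torch.nn.functional as F
from torch import nn
from timm.models.layers import get_act_fn, get_act_layer, create_act_layer

assert get_act_fn(F.relu) is F.relu             # callables are returned unchanged
assert get_act_layer(nn.ReLU) is nn.ReLU        # classes are returned unchanged
assert create_act_layer(None) is None           # falsy names still resolve to None
relu = create_act_layer('relu', inplace=True)   # inplace is only forwarded when it is not None
gelu = create_act_layer('gelu')                 # default inplace=None: constructor called without it
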

@@ -42,7 +42,7 @@ class EffectiveSEModule(nn.Module):
     def __init__(self, channels, gate_layer='hard_sigmoid'):
         super(EffectiveSEModule, self).__init__()
         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-        self.gate = create_act_layer(gate_layer, inplace=True)
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
         x_se = x.mean((2, 3), keepdim=True)

@@ -33,7 +33,7 @@ import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .layers import to_ntuple
+from .layers import to_ntuple, get_act_layer
 from .vision_transformer import trunc_normal_
 from .registry import register_model

@@ -443,12 +443,14 @@ class Levit(nn.Module):
             mlp_ratio=2,
             hybrid_backbone=None,
             down_ops=None,
-            act_layer=nn.Hardswish,
-            attn_act_layer=nn.Hardswish,
+            act_layer='hard_swish',
+            attn_act_layer='hard_swish',
             distillation=True,
             use_conv=False,
             drop_path=0):
         super().__init__()
+        act_layer = get_act_layer(act_layer)
+        attn_act_layer = get_act_layer(attn_act_layer)

         if isinstance(img_size, tuple):
             # FIXME origin impl passes single img/res dim through whole hierarchy,
             # not sure this model will be used enough to spend time fixing it.
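With the defaults moved to strings, LeViT's activations are resolved through `get_act_layer` at construction time rather than baked in as `nn.Hardswish`, so they pick up the native-op fallbacks and the scriptable/exportable/no_jit config like the other models, while callers that still pass an `nn.Module` class keep working via the new type pass-through. A hedged sketch of what the defaults resolve to, assuming a PyTorch build with the native Hardswish:

from torch import nn
from timm.models.layers import get_act_layer

act_layer = get_act_layer('hard_swish')
print(act_layer)                                      # nn.Hardswish here; timm's HardSwish* on older PyTorch
assert get_act_layer(nn.Hardswish) is nn.Hardswish    # passing the class directly still works
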
