diff --git a/tests/test_layers.py b/tests/test_layers.py index 714cb444..508a6aae 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -8,10 +8,10 @@ from timm.models.layers import create_act_layer, get_act_layer, set_layer_config class MLP(nn.Module): - def __init__(self, act_layer="relu"): + def __init__(self, act_layer="relu", inplace=True): super(MLP, self).__init__() self.fc1 = nn.Linear(1000, 100) - self.act = create_act_layer(act_layer, inplace=True) + self.act = create_act_layer(act_layer, inplace=inplace) self.fc2 = nn.Linear(100, 10) def forward(self, x): @@ -21,14 +21,14 @@ class MLP(nn.Module): return x -def _run_act_layer_grad(act_type): +def _run_act_layer_grad(act_type, inplace=True): x = torch.rand(10, 1000) * 10 - m = MLP(act_layer=act_type) + m = MLP(act_layer=act_type, inplace=inplace) def _run(x, act_layer=''): if act_layer: # replace act layer if set - m.act = create_act_layer(act_layer, inplace=True) + m.act = create_act_layer(act_layer, inplace=inplace) out = m(x) l = (out - 0).pow(2).sum() return l @@ -58,7 +58,7 @@ def test_mish_grad(): def test_hard_sigmoid_grad(): for _ in range(100): - _run_act_layer_grad('hard_sigmoid') + _run_act_layer_grad('hard_sigmoid', inplace=None) def test_hard_swish_grad(): diff --git a/tests/test_models.py b/tests/test_models.py index 44cb3ba2..18298dff 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -110,7 +110,7 @@ def test_model_backward(model_name, batch_size): assert not torch.isnan(outputs).any(), 'Output included NaNs' -@pytest.mark.timeout(120) +@pytest.mark.timeout(300) @pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_default_cfgs(model_name, batch_size): diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 83b57beb..7853db0e 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, drop_path, make_divisible +from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer from .layers.activations import sigmoid __all__ = [ @@ -36,9 +36,9 @@ class SqueezeExcite(nn.Module): reduced_chs = make_divisible(reduced_chs * se_ratio, divisor) act_layer = force_act_layer or act_layer self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) - self.act1 = act_layer(inplace=True) + self.act1 = create_act_layer(act_layer, inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - self.gate_fn = gate_fn + self.gate_fn = get_act_fn(gate_fn) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 30739454..57e2039b 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -50,10 +50,7 @@ def resolve_bn_args(kwargs): def resolve_act_layer(kwargs, default='relu'): - act_layer = kwargs.pop('act_layer', default) - if isinstance(act_layer, str): - act_layer = get_act_layer(act_layer) - return act_layer + return get_act_layer(kwargs.pop('act_layer', default)) def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9): diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index c132142a..1783ff7a 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible +from .layers import SelectAdaptivePool2d, Linear, make_divisible from .efficientnet_blocks import SqueezeExcite, ConvBnAct from .helpers import build_model_with_cfg from .registry import register_model @@ -40,7 +40,7 @@ default_cfgs = { } -_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4) +_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4) class GhostModule(nn.Module): diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 426c3681..aa557692 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -1,20 +1,26 @@ """ Activation Factory Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Union, Callable, Type + from .activations import * from .activations_jit import * from .activations_me import * from .config import is_exportable, is_scriptable, is_no_jit -# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code -# will use native version if present. Eventually, the custom Swish layers will be removed -# and only native 'silu' will be used. +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. _has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + _ACT_FN_DEFAULT = dict( silu=F.silu if _has_silu else swish, swish=F.silu if _has_silu else swish, - mish=mish, + mish=F.mish if _has_mish else mish, relu=F.relu, relu6=F.relu6, leaky_relu=F.leaky_relu, @@ -24,33 +30,39 @@ _ACT_FN_DEFAULT = dict( gelu=gelu, sigmoid=sigmoid, tanh=tanh, - hard_sigmoid=hard_sigmoid, - hard_swish=hard_swish, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, hard_mish=hard_mish, ) _ACT_FN_JIT = dict( silu=F.silu if _has_silu else swish_jit, swish=F.silu if _has_silu else swish_jit, - mish=mish_jit, - hard_sigmoid=hard_sigmoid_jit, - hard_swish=hard_swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, hard_mish=hard_mish_jit ) _ACT_FN_ME = dict( silu=F.silu if _has_silu else swish_me, swish=F.silu if _has_silu else swish_me, - mish=mish_me, - hard_sigmoid=hard_sigmoid_me, - hard_swish=hard_swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, hard_mish=hard_mish_me, ) +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + _ACT_LAYER_DEFAULT = dict( silu=nn.SiLU if _has_silu else Swish, swish=nn.SiLU if _has_silu else Swish, - mish=Mish, + mish=nn.Mish if _has_mish else Mish, relu=nn.ReLU, relu6=nn.ReLU6, leaky_relu=nn.LeakyReLU, @@ -61,37 +73,44 @@ _ACT_LAYER_DEFAULT = dict( gelu=GELU, sigmoid=Sigmoid, tanh=Tanh, - hard_sigmoid=HardSigmoid, - hard_swish=HardSwish, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, hard_mish=HardMish, ) _ACT_LAYER_JIT = dict( silu=nn.SiLU if _has_silu else SwishJit, swish=nn.SiLU if _has_silu else SwishJit, - mish=MishJit, - hard_sigmoid=HardSigmoidJit, - hard_swish=HardSwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, hard_mish=HardMishJit ) _ACT_LAYER_ME = dict( silu=nn.SiLU if _has_silu else SwishMe, swish=nn.SiLU if _has_silu else SwishMe, - mish=MishMe, - hard_sigmoid=HardSigmoidMe, - hard_swish=HardSwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, hard_mish=HardMishMe, ) +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + -def get_act_fn(name='relu'): +def get_act_fn(name: Union[Callable, str] = 'relu'): """ Activation Function Factory Fetching activation fns by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if not name: return None + if isinstance(name, Callable): + return name if not (is_no_jit() or is_exportable() or is_scriptable()): # If not exporting or scripting the model, first look for a memory-efficient version with # custom autograd, then fallback @@ -106,13 +125,15 @@ def get_act_fn(name='relu'): return _ACT_FN_DEFAULT[name] -def get_act_layer(name='relu'): +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): """ Activation Layer Factory Fetching activation layers by name with this function allows export or torch script friendly functions to be returned dynamically based on current config. """ if not name: return None + if isinstance(name, type): + return name if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] @@ -125,9 +146,8 @@ def get_act_layer(name='relu'): return _ACT_LAYER_DEFAULT[name] -def create_act_layer(name, inplace=False, **kwargs): +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): act_layer = get_act_layer(name) - if act_layer is not None: - return act_layer(inplace=inplace, **kwargs) - else: + if act_layer is None: return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py index 54c0ef33..4354144d 100644 --- a/timm/models/layers/se.py +++ b/timm/models/layers/se.py @@ -42,7 +42,7 @@ class EffectiveSEModule(nn.Module): def __init__(self, channels, gate_layer='hard_sigmoid'): super(EffectiveSEModule, self).__init__() self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - self.gate = create_act_layer(gate_layer, inplace=True) + self.gate = create_act_layer(gate_layer) def forward(self, x): x_se = x.mean((2, 3), keepdim=True) diff --git a/timm/models/levit.py b/timm/models/levit.py index 5019ee9a..2180254a 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -33,7 +33,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_ntuple +from .layers import to_ntuple, get_act_layer from .vision_transformer import trunc_normal_ from .registry import register_model @@ -443,12 +443,14 @@ class Levit(nn.Module): mlp_ratio=2, hybrid_backbone=None, down_ops=None, - act_layer=nn.Hardswish, - attn_act_layer=nn.Hardswish, + act_layer='hard_swish', + attn_act_layer='hard_swish', distillation=True, use_conv=False, drop_path=0): super().__init__() + act_layer = get_act_layer(act_layer) + attn_act_layer = get_act_layer(attn_act_layer) if isinstance(img_size, tuple): # FIXME origin impl passes single img/res dim through whole hierarchy, # not sure this model will be used enough to spend time fixing it.