Merge pull request #155 from rwightman/densenet_update_and_more
DenseNet updates, EvoNorms, VovNet, activation factory and more. Includes PR #142
commit d1b5dddad1
@@ -0,0 +1,71 @@
import pytest
import torch
import torch.nn as nn
import platform
import os

from timm.models.layers import create_act_layer, get_act_layer, set_layer_config


class MLP(nn.Module):
    def __init__(self, act_layer="relu"):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(1000, 100)
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


def _run_act_layer_grad(act_type):
    x = torch.rand(10, 1000) * 10
    m = MLP(act_layer=act_type)

    def _run(x, act_layer=''):
        if act_layer:
            # replace act layer if set
            m.act = create_act_layer(act_layer, inplace=True)
        out = m(x)
        l = (out - 0).pow(2).sum()
        return l

    out_me = _run(x)

    with set_layer_config(scriptable=True):
        out_jit = _run(x, act_type)

    assert torch.isclose(out_jit, out_me)

    with set_layer_config(no_jit=True):
        out_basic = _run(x, act_type)

    assert torch.isclose(out_basic, out_jit)


def test_swish_grad():
    for _ in range(100):
        _run_act_layer_grad('swish')


def test_mish_grad():
    for _ in range(100):
        _run_act_layer_grad('mish')


def test_hard_sigmoid_grad():
    for _ in range(100):
        _run_act_layer_grad('hard_sigmoid')


def test_hard_swish_grad():
    for _ in range(100):
        _run_act_layer_grad('hard_swish')


def test_hard_mish_grad():
    for _ in range(100):
        _run_act_layer_grad('hard_mish')
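These tests exercise the new activation factory end to end: each one builds a small MLP, swaps the activation in via create_act_layer, and checks that the default memory-efficient, jit-scripted, and basic variants produce matching outputs under set_layer_config. They are plain pytest tests, so they can be run with the usual pytest selection flags (e.g. `pytest -k hard_swish`), assuming the file is collected as part of the repository's test suite.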
@@ -1,2 +1,3 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
    is_scriptable, is_exportable, set_scriptable, set_exportable
@@ -1,22 +1,29 @@
from .padding import get_padding
from .pool2d_same import AvgPool2dSame
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .pool2d_same import create_pool2d
from .create_conv2d import create_conv2d
from .create_attn import create_attn
from .selective_kernel import SelectiveKernelConv
from .se import SEModule
from .eca import EcaModule, CecaModule
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .anti_aliasing import AntiAliasDownsampleLayer
from .space_to_depth import SpaceToDepthModule
from .blur_pool import BlurPool2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
    set_layer_config
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .inplace_abn import InplaceAbn
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_
@@ -0,0 +1,90 @@
""" Activations

A collection of jit-scripted activation fns and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

All jit-scripted activations omit in-place variations on purpose: scripted kernel fusion does not
currently work across in-place op boundaries, so performance would be equal to or worse than the
non-scripted versions if they contained in-place ops.

Hacked together by Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


@torch.jit.script
def swish_jit(x, inplace: bool = False):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    """
    return x.mul(x.sigmoid())


@torch.jit.script
def mish_jit(x, _inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    """
    return x.mul(F.softplus(x).tanh())


class SwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishJit, self).__init__()

    def forward(self, x):
        return swish_jit(x)


class MishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishJit, self).__init__()

    def forward(self, x):
        return mish_jit(x)


@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
    # return F.relu6(x + 3.) / 6.
    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?


class HardSigmoidJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidJit, self).__init__()

    def forward(self, x):
        return hard_sigmoid_jit(x)


@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    # return x * (F.relu6(x + 3.) / 6)
    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?


class HardSwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishJit, self).__init__()

    def forward(self, x):
        return hard_swish_jit(x)


@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
    """ Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


class HardMishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishJit, self).__init__()

    def forward(self, x):
        return hard_mish_jit(x)
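A quick sanity check of the clamp-based formulations above against the commented-out relu6 forms; a minimal sketch, assuming this module is importable as timm.models.layers.activations_jit (the path implied by the create_act imports later in this PR):

import torch
import torch.nn.functional as F

from timm.models.layers.activations_jit import hard_sigmoid_jit, hard_swish_jit  # assumed module path

x = torch.linspace(-6, 6, steps=121)
# (x + 3).clamp(0, 6) / 6 computes exactly relu6(x + 3) / 6
assert torch.allclose(hard_sigmoid_jit(x), F.relu6(x + 3.) / 6.)
assert torch.allclose(hard_swish_jit(x), x * F.relu6(x + 3.) / 6.)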
@@ -0,0 +1,208 @@
""" Activations (memory-efficient w/ custom autograd)

A collection of activation fns and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

These activations are not compatible with jit scripting or ONNX export of the model; please use either
the JIT or basic versions of the activations for those cases.

Hacked together by Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation between Jeremy Howard & Adam Paszke
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_me(x, inplace=False):
    return SwishJitAutoFn.apply(x)


class SwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
        return SwishJitAutoFn.apply(x)


@torch.jit.script
def mish_jit_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    A memory efficient, jit scripted variant of Mish
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)


def mish_me(x, inplace=False):
    return MishJitAutoFn.apply(x)


class MishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishMe, self).__init__()

    def forward(self, x):
        return MishJitAutoFn.apply(x)


@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    return (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
    return grad_output * m


class HardSigmoidJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
    return HardSigmoidJitAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
        return HardSigmoidJitAutoFn.apply(x)


@torch.jit.script
def hard_swish_jit_fwd(x):
    return x * (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= 3.)
    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
    return grad_output * m


class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(x, grad_output)


def hard_swish_me(x, inplace=False):
    return HardSwishJitAutoFn.apply(x)


class HardSwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishMe, self).__init__()

    def forward(self, x):
        return HardSwishJitAutoFn.apply(x)


@torch.jit.script
def hard_mish_jit_fwd(x):
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= -2.)
    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
    return grad_output * m


class HardMishJitAutoFn(torch.autograd.Function):
    """ A memory efficient, jit scripted variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_mish_jit_bwd(x, grad_output)


def hard_mish_me(x, inplace: bool = False):
    return HardMishJitAutoFn.apply(x)


class HardMishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishMe, self).__init__()

    def forward(self, x):
        return HardMishJitAutoFn.apply(x)
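The hand-written backward functions above can be verified against PyTorch's numerical gradients; a minimal sketch, assuming the memory-efficient activations land at timm.models.layers.activations_me:

import torch

from timm.models.layers.activations_me import SwishJitAutoFn, MishJitAutoFn  # assumed module path

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
# gradcheck compares swish_jit_bwd / mish_jit_bwd against finite-difference gradients
assert torch.autograd.gradcheck(SwishJitAutoFn.apply, (x,))
assert torch.autograd.gradcheck(MishJitAutoFn.apply, (x,))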
@@ -0,0 +1,115 @@
""" Model / Layer Config singleton state
"""
from typing import Any, Optional

__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]

# Set to True if you prefer layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True if you prefer activation layers with no jit optimization
# NOTE not currently used, as there is no difference between no_jit and no_activation_jit while the only
# layers obeying the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False


def is_no_jit():
    return _NO_JIT


class set_no_jit:
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False


def is_exportable():
    return _EXPORTABLE


class set_exportable:
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
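A minimal usage sketch for the config flags above, using only names this PR exports from timm.models.layers; the set_* classes act both as one-shot setters and as context managers that restore the previous state on exit:

from timm.models.layers import set_layer_config, is_scriptable, is_no_jit

with set_layer_config(scriptable=True):
    # inside the context, the layer factories prefer jit-scripted implementations
    assert is_scriptable() and not is_no_jit()
assert not is_scriptable()  # previous state restored on exit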
@@ -0,0 +1,114 @@
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit


_ACT_FN_DEFAULT = dict(
    swish=swish,
    mish=mish,
    relu=F.relu,
    relu6=F.relu6,
    leaky_relu=F.leaky_relu,
    elu=F.elu,
    prelu=F.prelu,
    celu=F.celu,
    selu=F.selu,
    gelu=F.gelu,
    sigmoid=sigmoid,
    tanh=tanh,
    hard_sigmoid=hard_sigmoid,
    hard_swish=hard_swish,
    hard_mish=hard_mish,
)

_ACT_FN_JIT = dict(
    swish=swish_jit,
    mish=mish_jit,
    hard_sigmoid=hard_sigmoid_jit,
    hard_swish=hard_swish_jit,
    hard_mish=hard_mish_jit
)

_ACT_FN_ME = dict(
    swish=swish_me,
    mish=mish_me,
    hard_sigmoid=hard_sigmoid_me,
    hard_swish=hard_swish_me,
    hard_mish=hard_mish_me,
)

_ACT_LAYER_DEFAULT = dict(
    swish=Swish,
    mish=Mish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    elu=nn.ELU,
    prelu=nn.PReLU,
    celu=nn.CELU,
    selu=nn.SELU,
    gelu=nn.GELU,
    sigmoid=Sigmoid,
    tanh=Tanh,
    hard_sigmoid=HardSigmoid,
    hard_swish=HardSwish,
    hard_mish=HardMish,
)

_ACT_LAYER_JIT = dict(
    swish=SwishJit,
    mish=MishJit,
    hard_sigmoid=HardSigmoidJit,
    hard_swish=HardSwishJit,
    hard_mish=HardMishJit
)

_ACT_LAYER_ME = dict(
    swish=SwishMe,
    mish=MishMe,
    hard_sigmoid=HardSigmoidMe,
    hard_swish=HardSwishMe,
    hard_mish=HardMishMe,
)


def get_act_fn(name='relu'):
    """ Activation Function Factory
    Fetching activation fns by name with this function allows export- or torch-script-friendly
    functions to be returned dynamically, based on the current config.
    """
    if not name:
        return None
    if not (is_no_jit() or is_exportable() or is_scriptable()):
        # If not exporting or scripting the model, first look for a memory-efficient version with
        # custom autograd, then fall back
        if name in _ACT_FN_ME:
            return _ACT_FN_ME[name]
    if not is_no_jit():
        if name in _ACT_FN_JIT:
            return _ACT_FN_JIT[name]
    return _ACT_FN_DEFAULT[name]


def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this function allows export- or torch-script-friendly
    layers to be returned dynamically, based on the current config.
    """
    if not name:
        return None
    if not (is_no_jit() or is_exportable() or is_scriptable()):
        if name in _ACT_LAYER_ME:
            return _ACT_LAYER_ME[name]
    if not is_no_jit():
        if name in _ACT_LAYER_JIT:
            return _ACT_LAYER_JIT[name]
    return _ACT_LAYER_DEFAULT[name]


def create_act_layer(name, inplace=False, **kwargs):
    act_layer = get_act_layer(name)
    if act_layer is not None:
        return act_layer(inplace=inplace, **kwargs)
    else:
        return None
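To illustrate how the factory resolves the same name differently per config, a small sketch using only the exported helpers; by default the memory-efficient variant is picked, while the other configs fall back to the jit-scripted or basic versions:

import torch
from timm.models.layers import create_act_layer, get_act_fn, set_layer_config

x = torch.randn(2, 8)
act = create_act_layer('hard_swish', inplace=True)   # default config -> memory-efficient HardSwishMe
y = act(x)

with set_layer_config(scriptable=True):
    fn = get_act_fn('hard_swish')                    # scriptable config -> hard_swish_jit
assert torch.allclose(y, fn(x))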
@@ -0,0 +1,64 @@
import types
import functools

import torch
import torch.nn as nn

from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .inplace_abn import InplaceAbn

_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}


def get_norm_act_layer(layer_class):
    layer_class = layer_class.replace('_', '').lower()
    if layer_class.startswith("batchnorm"):
        layer = BatchNormAct2d
    elif layer_class.startswith("groupnorm"):
        layer = GroupNormAct
    elif layer_class == "evonormbatch":
        layer = EvoNormBatch2d
    elif layer_class == "evonormsample":
        layer = EvoNormSample2d
    elif layer_class == "iabn" or layer_class == "inplaceabn":
        layer = InplaceAbn
    else:
        assert False, "Invalid norm_act layer (%s)" % layer_class
    return layer


def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
    layer_parts = layer_type.split('-')  # e.g. batchnorm-leaky_relu
    assert len(layer_parts) in (1, 2)
    layer = get_norm_act_layer(layer_parts[0])
    #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection?
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
    if jit:
        layer_instance = torch.jit.script(layer_instance)
    return layer_instance


def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
    norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
    if isinstance(norm_layer, str):
        norm_act_layer = get_norm_act_layer(norm_layer)
    elif norm_layer in _NORM_ACT_TYPES:
        norm_act_layer = norm_layer
    elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
        # assuming this is a lambda/fn/bound partial that creates norm_act layer
        norm_act_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower()
        if type_name.startswith('batchnorm'):
            norm_act_layer = BatchNormAct2d
        elif type_name.startswith('groupnorm'):
            norm_act_layer = GroupNormAct
        else:
            assert False, f"No equivalent norm_act layer for {type_name}"
    # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
    # Newer models will use `apply_act` and likely have `act_layer` arg bound to relevant NormAct types.
    norm_act_args.update(dict(act_layer=act_layer))
    return norm_act_layer, norm_act_args
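A small usage sketch for the norm+act factory above; names are normalized (underscores stripped, lower-cased), so for example 'EvoNormBatch' and 'evonormbatch' resolve to the same layer:

import torch
from timm.models.layers import create_norm_act, get_norm_act_layer

norm_act = create_norm_act('evonormbatch', 32, apply_act=True)
x = torch.randn(2, 32, 8, 8)
assert norm_act(x).shape == x.shape

bn_act_cls = get_norm_act_layer('batchnorm')  # returns the BatchNormAct2d class, usable as a norm_layer arg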
@@ -0,0 +1,81 @@
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch

An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other implementations, it is still quite a ways off the built-in BN
in terms of memory usage and throughput (roughly 5x the memory, 1/2 to 1/3 the speed).

Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.

Hacked together by Ross Wightman
"""

import torch
import torch.nn as nn


class EvoNormBatch2d(nn.Module):
    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
        super(EvoNormBatch2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if apply_act:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.apply_act:
            nn.init.ones_(self.v)

    def forward(self, x):
        assert x.dim() == 4, 'expected 4D input'
        x_type = x.dtype
        if self.training:
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
        else:
            var = self.running_var

        if self.apply_act:
            v = self.v.to(dtype=x_type)
            d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
            d = d.max((var + self.eps).sqrt().to(dtype=x_type))
            x = x / d
        return x * self.weight + self.bias


class EvoNormSample2d(nn.Module):
    def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
        super(EvoNormSample2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.groups = groups
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if apply_act:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.apply_act:
            nn.init.ones_(self.v)

    def forward(self, x):
        assert x.dim() == 4, 'expected 4D input'
        B, C, H, W = x.shape
        assert C % self.groups == 0
        if self.apply_act:
            n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
            x = x.reshape(B, self.groups, -1)
            x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
            x = x.reshape(B, C, H, W)
        return x * self.weight + self.bias
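Both EvoNorm layers are shape-preserving drop-ins for a BatchNorm+activation pair; a quick shape check (EvoNormSample2d additionally requires the channel count to be divisible by groups):

import torch
from timm.models.layers import EvoNormBatch2d, EvoNormSample2d

x = torch.randn(2, 64, 16, 16)
assert EvoNormBatch2d(64)(x).shape == x.shape
assert EvoNormSample2d(64, groups=8)(x).shape == x.shape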
@@ -0,0 +1,85 @@
import torch
from torch import nn as nn

try:
    from inplace_abn.functions import inplace_abn, inplace_abn_sync
    has_iabn = True
except ImportError:
    has_iabn = False

    def inplace_abn(x, weight, bias, running_mean, running_var,
                    training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
        raise ImportError(
            "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")

    def inplace_abn_sync(**kwargs):
        inplace_abn(**kwargs)


class InplaceAbn(nn.Module):
    """Activated Batch Normalization

    This gathers a BatchNorm and an activation function in a single module

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input and output.
    eps : float
        Small constant to prevent numerical issues.
    momentum : float
        Momentum factor applied to compute running statistics.
    affine : bool
        If `True` apply learned scale and shift transformation after normalization.
    act_layer : str or nn.Module type
        Name or type of the activation functions, one of: `leaky_relu`, `elu`
    act_param : float
        Negative slope for the `leaky_relu` activation.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
                 act_layer="leaky_relu", act_param=0.01, drop_block=None,):
        super(InplaceAbn, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if apply_act:
            if isinstance(act_layer, str):
                assert act_layer in ('leaky_relu', 'elu', 'identity')
                self.act_name = act_layer
            else:
                # convert act layer passed as type to string
                if isinstance(act_layer, nn.ELU):
                    self.act_name = 'elu'
                elif isinstance(act_layer, nn.LeakyReLU):
                    self.act_name = 'leaky_relu'
                else:
                    assert False, f'Invalid act layer {act_layer.__name__} for IABN'
        else:
            self.act_name = 'identity'
        self.act_param = act_param
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        output = inplace_abn(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.training, self.momentum, self.eps, self.act_name, self.act_param)
        if isinstance(output, tuple):
            output = output[0]
        return output
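Note that InplaceAbn only needs the optional inplace_abn package at forward time; constructing the module works without it, but calling it raises the ImportError from the fallback stub above. A minimal sketch:

import torch
from timm.models.layers import InplaceAbn

iabn = InplaceAbn(64, act_layer='leaky_relu', act_param=0.01)  # fine without the package installed
# iabn(torch.randn(2, 64, 8, 8))  # raises ImportError unless inplace_abn is installed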
@@ -0,0 +1,85 @@
""" Normalization + Activation Layers
"""
import torch
from torch import nn as nn
from torch.nn import functional as F

from .create_act import get_act_layer


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(BatchNormAct2d, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            self.act = act_layer(inplace=inplace)
        else:
            self.act = None

    def _forward_jit(self, x):
        """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
        """
        # exponential_average_factor is set to self.momentum (when it is available) only so
        # that it gets updated in the ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        x = F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
        return x

    @torch.jit.ignore
    def _forward_python(self, x):
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        # FIXME cannot call parent forward() and maintain jit.script compatibility?
        if torch.jit.is_scripting():
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        if self.act is not None:
            x = self.act(x)
        return x


class GroupNormAct(nn.GroupNorm):

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            self.act = act_layer(inplace=inplace)
        else:
            self.act = None

    def forward(self, x):
        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        if self.act is not None:
            x = self.act(x)
        return x
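A small usage sketch for BatchNormAct2d; the act_layer argument can be a string, which is resolved through the activation factory added in this PR:

import torch
from timm.models.layers import BatchNormAct2d

bn_act = BatchNormAct2d(32, act_layer='relu', apply_act=True)
x = torch.randn(2, 32, 8, 8)
y = bn_act(x)
assert y.shape == x.shape and y.min() >= 0  # ReLU applied after the normalization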
@@ -0,0 +1,51 @@
from torch import nn as nn

from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type


class SeparableConvBnAct(nn.Module):
    """ Separable Conv w/ trailing Norm and Activation
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
                 channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 act_layer=nn.ReLU, apply_act=True, drop_block=None):
        super(SeparableConvBnAct, self).__init__()
        norm_kwargs = norm_kwargs or {}

        self.conv_dw = create_conv2d(
            in_channels, int(in_channels * channel_multiplier), kernel_size,
            stride=stride, dilation=dilation, padding=padding, depthwise=True)

        self.conv_pw = create_conv2d(
            int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)

        norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)

    def forward(self, x):
        x = self.conv_dw(x)
        x = self.conv_pw(x)
        if self.bn is not None:
            x = self.bn(x)
        return x


class SeparableConv2d(nn.Module):
    """ Separable Conv
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
                 channel_multiplier=1.0, pw_kernel_size=1):
        super(SeparableConv2d, self).__init__()

        self.conv_dw = create_conv2d(
            in_channels, int(in_channels * channel_multiplier), kernel_size,
            stride=stride, dilation=dilation, padding=padding, depthwise=True)

        self.conv_pw = create_conv2d(
            int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)

    def forward(self, x):
        x = self.conv_dw(x)
        x = self.conv_pw(x)
        return x
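SeparableConv2d factors a convolution into a depthwise k x k conv followed by a pointwise 1x1 projection; a quick shape check, assuming create_conv2d's default '' padding resolves to 'same'-style padding as elsewhere in timm:

import torch
from timm.models.layers import SeparableConv2d, SeparableConvBnAct

x = torch.randn(1, 32, 56, 56)
assert SeparableConv2d(32, 64, kernel_size=3)(x).shape == (1, 64, 56, 56)
assert SeparableConvBnAct(32, 64, kernel_size=3)(x).shape == (1, 64, 56, 56)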
@@ -0,0 +1,414 @@
""" VoVNet (V1 & V2)

Papers:
* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667

Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
for some reference, rewrote most of the code.

Hacked together by Ross Wightman
"""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .helpers import load_pretrained
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
    create_attn, create_norm_act, get_norm_act_layer


# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
model_cfgs = dict(
    vovnet39a=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 2, 2],
        residual=False,
        depthwise=False,
        attn='',
    ),
    vovnet57a=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 4, 3],
        residual=False,
        depthwise=False,
        attn='',
    ),
    ese_vovnet19b_slim_dw=dict(
        stem_chs=[64, 64, 64],
        stage_conv_chs=[64, 80, 96, 112],
        stage_out_chs=[112, 256, 384, 512],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=True,
        attn='ese',
    ),
    ese_vovnet19b_dw=dict(
        stem_chs=[64, 64, 64],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=True,
        attn='ese',
    ),
    ese_vovnet19b_slim=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[64, 80, 96, 112],
        stage_out_chs=[112, 256, 384, 512],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    ese_vovnet19b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    ese_vovnet39b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 2, 2],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    ese_vovnet57b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 4, 3],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    ese_vovnet99b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 3, 9, 3],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    eca_vovnet39b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 2, 2],
        residual=True,
        depthwise=False,
        attn='eca',
    ),
)
model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b']


def _cfg(url=''):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
    }


default_cfgs = dict(
    vovnet39a=_cfg(url=''),
    vovnet57a=_cfg(url=''),
    ese_vovnet19b_slim_dw=_cfg(url=''),
    ese_vovnet19b_dw=_cfg(url=''),
    ese_vovnet19b_slim=_cfg(url=''),
    ese_vovnet39b=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
    ese_vovnet57b=_cfg(url=''),
    ese_vovnet99b=_cfg(url=''),
    eca_vovnet39b=_cfg(url=''),
    ese_vovnet39b_evos=_cfg(url=''),
    ese_vovnet99b_iabn=_cfg(url=''),
)


class SequentialAppendList(nn.Sequential):
    def __init__(self, *args):
        super(SequentialAppendList, self).__init__(*args)

    def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
        for i, module in enumerate(self):
            if i == 0:
                concat_list.append(module(x))
            else:
                concat_list.append(module(concat_list[-1]))
        x = torch.cat(concat_list, dim=1)
        return x


class OsaBlock(nn.Module):

    def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
                 depthwise=False, attn='', norm_layer=BatchNormAct2d):
        super(OsaBlock, self).__init__()

        self.residual = residual
        self.depthwise = depthwise

        next_in_chs = in_chs
        if self.depthwise and next_in_chs != mid_chs:
            assert not residual
            self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, norm_layer=norm_layer)
        else:
            self.conv_reduction = None

        mid_convs = []
        for i in range(layer_per_block):
            if self.depthwise:
                conv = SeparableConvBnAct(mid_chs, mid_chs, norm_layer=norm_layer)
            else:
                conv = ConvBnAct(next_in_chs, mid_chs, 3, norm_layer=norm_layer)
            next_in_chs = mid_chs
            mid_convs.append(conv)
        self.conv_mid = SequentialAppendList(*mid_convs)

        # feature aggregation
        next_in_chs = in_chs + layer_per_block * mid_chs
        self.conv_concat = ConvBnAct(next_in_chs, out_chs, norm_layer=norm_layer)

        if attn:
            self.attn = create_attn(attn, out_chs)
        else:
            self.attn = None

    def forward(self, x):
        output = [x]
        if self.conv_reduction is not None:
            x = self.conv_reduction(x)
        x = self.conv_mid(x, output)
        x = self.conv_concat(x)
        if self.attn is not None:
            x = self.attn(x)
        if self.residual:
            x = x + output[0]
        return x


class OsaStage(nn.Module):

    def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block,
                 downsample=True, residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d):
        super(OsaStage, self).__init__()

        if downsample:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        else:
            self.pool = None

        blocks = []
        for i in range(block_per_stage):
            last_block = i == block_per_stage - 1
            blocks += [OsaBlock(
                in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0,
                depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer)
            ]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        if self.pool is not None:
            x = self.pool(x)
        x = self.blocks(x)
        return x


class ClassifierHead(nn.Module):
    """Head."""

    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        if num_classes > 0:
            self.fc = nn.Linear(in_chs, num_classes, bias=True)
        else:
            self.fc = nn.Identity()

    def forward(self, x):
        x = self.global_pool(x).flatten(1)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        return x


class VovNet(nn.Module):

    def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
                 norm_layer=BatchNormAct2d):
        """ VovNet (v2)
        """
        super(VovNet, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        assert stem_stride in (4, 2)

        stem_chs = cfg["stem_chs"]
        stage_conv_chs = cfg["stage_conv_chs"]
        stage_out_chs = cfg["stage_out_chs"]
        block_per_stage = cfg["block_per_stage"]
        layer_per_block = cfg["layer_per_block"]

        # Stem module
        last_stem_stride = stem_stride // 2
        conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct
        self.stem = nn.Sequential(*[
            ConvBnAct(in_chans, stem_chs[0], 3, stride=2, norm_layer=norm_layer),
            conv_type(stem_chs[0], stem_chs[1], 3, stride=1, norm_layer=norm_layer),
            conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
        ])

        # OSA stages
        in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
        stage_args = dict(
            residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], norm_layer=norm_layer)
        stages = []
        for i in range(4):  # num_stages
            downsample = stem_stride == 2 or i > 0  # first stage has no stride/downsample if stem_stride is 4
            stages += [OsaStage(
                in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
                downsample=downsample, **stage_args)
            ]
            self.num_features = stage_out_chs[i]
        self.stages = nn.Sequential(*stages)

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        return self.stages(x)

    def forward(self, x):
        x = self.forward_features(x)
        return self.head(x)


def _vovnet(variant, pretrained=False, **kwargs):
    load_strict = True
    model_class = VovNet
    if kwargs.pop('features_only', False):
        assert False, 'Not Implemented'  # TODO
        load_strict = False
        kwargs.pop('num_classes', 0)
    model_cfg = model_cfgs[variant]
    default_cfg = default_cfgs[variant]
    model = model_class(model_cfg, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(
            model, default_cfg,
            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
    return model


@register_model
def vovnet39a(pretrained=False, **kwargs):
    return _vovnet('vovnet39a', pretrained=pretrained, **kwargs)


@register_model
def vovnet57a(pretrained=False, **kwargs):
    return _vovnet('vovnet57a', pretrained=pretrained, **kwargs)


@register_model
def ese_vovnet19b_slim_dw(pretrained=False, **kwargs):
    return _vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)


@register_model
def ese_vovnet19b_dw(pretrained=False, **kwargs):
    return _vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)


@register_model
def ese_vovnet19b_slim(pretrained=False, **kwargs):
    return _vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)


@register_model
def ese_vovnet39b(pretrained=False, **kwargs):
    return _vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)


@register_model
def ese_vovnet57b(pretrained=False, **kwargs):
    return _vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)


@register_model
def ese_vovnet99b(pretrained=False, **kwargs):
    return _vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)


@register_model
def eca_vovnet39b(pretrained=False, **kwargs):
    return _vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)


# Experimental Models

@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs):
    def norm_act_fn(num_features, **kwargs):
        return create_norm_act('EvoNormSample', num_features, jit=False, **kwargs)
    return _vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)


@register_model
def ese_vovnet99b_iabn(pretrained=False, **kwargs):
    norm_layer = get_norm_act_layer('iabn')
    return _vovnet('ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
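The variants above are all registered with the model registry, so they can be built through the usual timm entrypoint; of the new configs, only ese_vovnet39b currently has a pretrained weight URL. A minimal sketch:

import torch
from timm import create_model

model = create_model('ese_vovnet39b', pretrained=False, num_classes=10)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
assert out.shape == (1, 10)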
@@ -1 +1 @@
__version__ = '0.1.26'
__version__ = '0.1.28'