* refactor activations into basic PyTorch, jit scripted, and memory efficient custom auto * implement hard-mish, better grad for hard-swish * add initial VovNet V1/V2 impl, fix #151 * VovNet and DenseNet first models to use NormAct layers (support BatchNormAct2d, EvoNorm, InplaceIABN) * Wrap IABN for any models that use it * make more models torchscript compatible (DPN, PNasNet, Res2Net, SelecSLS) and add testspull/155/head
parent
ff94ffce61
commit
eb7653614f
@ -1,2 +1,3 @@
|
|||||||
from .version import __version__
|
from .version import __version__
|
||||||
from .models import create_model, list_models, is_model, list_modules, model_entrypoint
|
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
|
||||||
|
is_scriptable, is_exportable, set_scriptable, set_exportable
|
||||||
|
@ -1,25 +1,28 @@
|
|||||||
from .padding import get_padding
|
|
||||||
from .pool2d_same import AvgPool2dSame
|
|
||||||
from .conv2d_same import Conv2dSame
|
|
||||||
from .conv_bn_act import ConvBnAct
|
|
||||||
from .mixed_conv2d import MixedConv2d
|
|
||||||
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
|
||||||
from .pool2d_same import create_pool2d
|
|
||||||
from .create_conv2d import create_conv2d
|
|
||||||
from .create_attn import create_attn
|
|
||||||
from .selective_kernel import SelectiveKernelConv
|
|
||||||
from .se import SEModule
|
|
||||||
from .eca import EcaModule, CecaModule
|
|
||||||
from .activations import *
|
from .activations import *
|
||||||
from .adaptive_avgmax_pool import \
|
from .adaptive_avgmax_pool import \
|
||||||
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
||||||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
|
||||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
|
||||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
|
||||||
from .anti_aliasing import AntiAliasDownsampleLayer
|
from .anti_aliasing import AntiAliasDownsampleLayer
|
||||||
from .space_to_depth import SpaceToDepthModule
|
|
||||||
from .blur_pool import BlurPool2d
|
from .blur_pool import BlurPool2d
|
||||||
from .norm_act import BatchNormAct2d
|
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||||
|
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit
|
||||||
|
from .conv2d_same import Conv2dSame
|
||||||
|
from .conv_bn_act import ConvBnAct
|
||||||
|
from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||||
|
from .create_attn import create_attn
|
||||||
|
from .create_conv2d import create_conv2d
|
||||||
|
from .create_norm_act import create_norm_act, get_norm_act_layer
|
||||||
|
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||||
|
from .eca import EcaModule, CecaModule
|
||||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||||
from .create_norm_act import create_norm_act
|
from .inplace_abn import InplaceAbn
|
||||||
|
from .mixed_conv2d import MixedConv2d
|
||||||
|
from .norm_act import BatchNormAct2d
|
||||||
|
from .padding import get_padding
|
||||||
|
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||||
|
from .se import SEModule
|
||||||
|
from .selective_kernel import SelectiveKernelConv
|
||||||
|
from .separable_conv import SeparableConv2d, SeparableConvBnAct
|
||||||
|
from .space_to_depth import SpaceToDepthModule
|
||||||
|
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||||
|
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||||
from .weight_init import trunc_normal_
|
from .weight_init import trunc_normal_
|
||||||
|
@ -0,0 +1,90 @@
|
|||||||
|
""" Activations
|
||||||
|
|
||||||
|
A collection of jit-scripted activations fn and modules with a common interface so that they can
|
||||||
|
easily be swapped. All have an `inplace` arg even if not used.
|
||||||
|
|
||||||
|
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
|
||||||
|
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
|
||||||
|
versions if they contain in-place ops.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def swish_jit(x, inplace: bool = False):
|
||||||
|
"""Swish - Described in: https://arxiv.org/abs/1710.05941
|
||||||
|
"""
|
||||||
|
return x.mul(x.sigmoid())
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def mish_jit(x, _inplace: bool = False):
|
||||||
|
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||||
|
"""
|
||||||
|
return x.mul(F.softplus(x).tanh())
|
||||||
|
|
||||||
|
|
||||||
|
class SwishJit(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(SwishJit, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return swish_jit(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MishJit(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(MishJit, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return mish_jit(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_sigmoid_jit(x, inplace: bool = False):
|
||||||
|
# return F.relu6(x + 3.) / 6.
|
||||||
|
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
||||||
|
|
||||||
|
|
||||||
|
class HardSigmoidJit(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(HardSigmoidJit, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return hard_sigmoid_jit(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_swish_jit(x, inplace: bool = False):
|
||||||
|
# return x * (F.relu6(x + 3.) / 6)
|
||||||
|
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
||||||
|
|
||||||
|
|
||||||
|
class HardSwishJit(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(HardSwishJit, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return hard_swish_jit(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_mish_jit(x, inplace: bool = False):
|
||||||
|
""" Hard Mish
|
||||||
|
Experimental, based on notes by Mish author Diganta Misra at
|
||||||
|
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
|
||||||
|
"""
|
||||||
|
return 0.5 * x * (x + 2).clamp(min=0, max=2)
|
||||||
|
|
||||||
|
|
||||||
|
class HardMishJit(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(HardMishJit, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return hard_mish_jit(x)
|
@ -0,0 +1,208 @@
|
|||||||
|
""" Activations (memory-efficient w/ custom autograd)
|
||||||
|
|
||||||
|
A collection of activations fn and modules with a common interface so that they can
|
||||||
|
easily be swapped. All have an `inplace` arg even if not used.
|
||||||
|
|
||||||
|
These activations are not compatible with jit scripting or ONNX export of the model, please use either
|
||||||
|
the JIT or basic versions of the activations.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def swish_jit_fwd(x):
|
||||||
|
return x.mul(torch.sigmoid(x))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def swish_jit_bwd(x, grad_output):
|
||||||
|
x_sigmoid = torch.sigmoid(x)
|
||||||
|
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
|
||||||
|
|
||||||
|
|
||||||
|
class SwishJitAutoFn(torch.autograd.Function):
|
||||||
|
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
|
||||||
|
Inspired by conversation btw Jeremy Howard & Adam Pazske
|
||||||
|
https://twitter.com/jeremyphoward/status/1188251041835315200
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
return swish_jit_fwd(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
x = ctx.saved_tensors[0]
|
||||||
|
return swish_jit_bwd(x, grad_output)
|
||||||
|
|
||||||
|
|
||||||
|
def swish_me(x, inplace=False):
|
||||||
|
return SwishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SwishMe(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(SwishMe, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return SwishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def mish_jit_fwd(x):
|
||||||
|
return x.mul(torch.tanh(F.softplus(x)))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def mish_jit_bwd(x, grad_output):
|
||||||
|
x_sigmoid = torch.sigmoid(x)
|
||||||
|
x_tanh_sp = F.softplus(x).tanh()
|
||||||
|
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
|
||||||
|
|
||||||
|
|
||||||
|
class MishJitAutoFn(torch.autograd.Function):
|
||||||
|
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
||||||
|
A memory efficient, jit scripted variant of Mish
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
return mish_jit_fwd(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
x = ctx.saved_tensors[0]
|
||||||
|
return mish_jit_bwd(x, grad_output)
|
||||||
|
|
||||||
|
|
||||||
|
def mish_me(x, inplace=False):
|
||||||
|
return MishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MishMe(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(MishMe, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return MishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
|
||||||
|
return (x + 3).clamp(min=0, max=6).div(6.)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_sigmoid_jit_bwd(x, grad_output):
|
||||||
|
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
|
||||||
|
return grad_output * m
|
||||||
|
|
||||||
|
|
||||||
|
class HardSigmoidJitAutoFn(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
return hard_sigmoid_jit_fwd(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
x = ctx.saved_tensors[0]
|
||||||
|
return hard_sigmoid_jit_bwd(x, grad_output)
|
||||||
|
|
||||||
|
|
||||||
|
def hard_sigmoid_me(x, inplace: bool = False):
|
||||||
|
return HardSigmoidJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class HardSigmoidMe(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(HardSigmoidMe, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return HardSigmoidJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_swish_jit_fwd(x):
|
||||||
|
return x * (x + 3).clamp(min=0, max=6).div(6.)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_swish_jit_bwd(x, grad_output):
|
||||||
|
m = torch.ones_like(x) * (x >= 3.)
|
||||||
|
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
|
||||||
|
return grad_output * m
|
||||||
|
|
||||||
|
|
||||||
|
class HardSwishJitAutoFn(torch.autograd.Function):
|
||||||
|
"""A memory efficient, jit-scripted HardSwish activation"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
return hard_swish_jit_fwd(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
x = ctx.saved_tensors[0]
|
||||||
|
return hard_swish_jit_bwd(x, grad_output)
|
||||||
|
|
||||||
|
|
||||||
|
def hard_swish_me(x, inplace=False):
|
||||||
|
return HardSwishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class HardSwishMe(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(HardSwishMe, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return HardSwishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_mish_jit_fwd(x):
|
||||||
|
return 0.5 * x * (x + 2).clamp(min=0, max=2)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def hard_mish_jit_bwd(x, grad_output):
|
||||||
|
m = torch.ones_like(x) * (x >= -2.)
|
||||||
|
m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
|
||||||
|
return grad_output * m
|
||||||
|
|
||||||
|
|
||||||
|
class HardMishJitAutoFn(torch.autograd.Function):
|
||||||
|
""" A memory efficient, jit scripted variant of Hard Mish
|
||||||
|
Experimental, based on notes by Mish author Diganta Misra at
|
||||||
|
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
return mish_jit_fwd(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
x = ctx.saved_tensors[0]
|
||||||
|
return mish_jit_bwd(x, grad_output)
|
||||||
|
|
||||||
|
|
||||||
|
def hard_mish_me(x, inplace: bool = False):
|
||||||
|
return HardMishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class HardMishMe(nn.Module):
|
||||||
|
def __init__(self, inplace: bool = False):
|
||||||
|
super(HardMishMe, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return HardMishJitAutoFn.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,74 @@
|
|||||||
|
""" Model / Layer Config Singleton
|
||||||
|
"""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit']
|
||||||
|
|
||||||
|
# Set to True if prefer to have layers with no jit optimization (includes activations)
|
||||||
|
_NO_JIT = False
|
||||||
|
|
||||||
|
# Set to True if prefer to have activation layers with no jit optimization
|
||||||
|
_NO_ACTIVATION_JIT = False
|
||||||
|
|
||||||
|
# Set to True if exporting a model with Same padding via ONNX
|
||||||
|
_EXPORTABLE = False
|
||||||
|
|
||||||
|
# Set to True if wanting to use torch.jit.script on a model
|
||||||
|
_SCRIPTABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_no_jit():
|
||||||
|
return _NO_JIT
|
||||||
|
|
||||||
|
|
||||||
|
class set_no_jit:
|
||||||
|
def __init__(self, mode: bool) -> None:
|
||||||
|
global _NO_JIT
|
||||||
|
self.prev = _NO_JIT
|
||||||
|
_NO_JIT = mode
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args: Any) -> bool:
|
||||||
|
global _NO_JIT
|
||||||
|
_NO_JIT = self.prev
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_exportable():
|
||||||
|
return _EXPORTABLE
|
||||||
|
|
||||||
|
|
||||||
|
class set_exportable:
|
||||||
|
def __init__(self, mode: bool) -> None:
|
||||||
|
global _EXPORTABLE
|
||||||
|
self.prev = _EXPORTABLE
|
||||||
|
_EXPORTABLE = mode
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args: Any) -> bool:
|
||||||
|
global _EXPORTABLE
|
||||||
|
_EXPORTABLE = self.prev
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_scriptable():
|
||||||
|
return _SCRIPTABLE
|
||||||
|
|
||||||
|
|
||||||
|
class set_scriptable:
|
||||||
|
def __init__(self, mode: bool) -> None:
|
||||||
|
global _SCRIPTABLE
|
||||||
|
self.prev = _SCRIPTABLE
|
||||||
|
_SCRIPTABLE = mode
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args: Any) -> bool:
|
||||||
|
global _SCRIPTABLE
|
||||||
|
_SCRIPTABLE = self.prev
|
||||||
|
return False
|
@ -0,0 +1,103 @@
|
|||||||
|
from .activations import *
|
||||||
|
from .activations_jit import *
|
||||||
|
from .activations_me import *
|
||||||
|
from .config import is_exportable, is_scriptable, is_no_jit
|
||||||
|
|
||||||
|
|
||||||
|
_ACT_FN_DEFAULT = dict(
|
||||||
|
swish=swish,
|
||||||
|
mish=mish,
|
||||||
|
relu=F.relu,
|
||||||
|
relu6=F.relu6,
|
||||||
|
sigmoid=sigmoid,
|
||||||
|
tanh=tanh,
|
||||||
|
hard_sigmoid=hard_sigmoid,
|
||||||
|
hard_swish=hard_swish,
|
||||||
|
hard_mish=hard_mish,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ACT_FN_JIT = dict(
|
||||||
|
swish=swish_jit,
|
||||||
|
mish=mish_jit,
|
||||||
|
hard_sigmoid=hard_sigmoid_jit,
|
||||||
|
hard_swish=hard_swish_jit,
|
||||||
|
hard_mish=hard_mish_jit
|
||||||
|
)
|
||||||
|
|
||||||
|
_ACT_FN_ME = dict(
|
||||||
|
swish=swish_me,
|
||||||
|
mish=mish_me,
|
||||||
|
hard_sigmoid=hard_sigmoid_me,
|
||||||
|
hard_swish=hard_swish_me,
|
||||||
|
hard_mish=hard_mish_me,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ACT_LAYER_DEFAULT = dict(
|
||||||
|
swish=Swish,
|
||||||
|
mish=Mish,
|
||||||
|
relu=nn.ReLU,
|
||||||
|
relu6=nn.ReLU6,
|
||||||
|
sigmoid=Sigmoid,
|
||||||
|
tanh=Tanh,
|
||||||
|
hard_sigmoid=HardSigmoid,
|
||||||
|
hard_swish=HardSwish,
|
||||||
|
hard_mish=HardMish,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ACT_LAYER_JIT = dict(
|
||||||
|
swish=SwishJit,
|
||||||
|
mish=MishJit,
|
||||||
|
hard_sigmoid=HardSigmoidJit,
|
||||||
|
hard_swish=HardSwishJit,
|
||||||
|
hard_mish=HardMishJit
|
||||||
|
)
|
||||||
|
|
||||||
|
_ACT_LAYER_ME = dict(
|
||||||
|
swish=SwishMe,
|
||||||
|
mish=MishMe,
|
||||||
|
hard_sigmoid=HardSigmoidMe,
|
||||||
|
hard_swish=HardSwishMe,
|
||||||
|
hard_mish=HardMishMe,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_act_fn(name='relu'):
|
||||||
|
""" Activation Function Factory
|
||||||
|
Fetching activation fns by name with this function allows export or torch script friendly
|
||||||
|
functions to be returned dynamically based on current config.
|
||||||
|
"""
|
||||||
|
if not name:
|
||||||
|
return None
|
||||||
|
if not (is_no_jit() or is_exportable() or is_scriptable()):
|
||||||
|
# If not exporting or scripting the model, first look for a memory-efficient version with
|
||||||
|
# custom autograd, then fallback
|
||||||
|
if name in _ACT_FN_ME:
|
||||||
|
return _ACT_FN_ME[name]
|
||||||
|
if not is_no_jit():
|
||||||
|
if name in _ACT_FN_JIT:
|
||||||
|
return _ACT_FN_JIT[name]
|
||||||
|
return _ACT_FN_DEFAULT[name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_act_layer(name='relu'):
|
||||||
|
""" Activation Layer Factory
|
||||||
|
Fetching activation layers by name with this function allows export or torch script friendly
|
||||||
|
functions to be returned dynamically based on current config.
|
||||||
|
"""
|
||||||
|
if not name:
|
||||||
|
return None
|
||||||
|
if not (is_no_jit() or is_exportable() or is_scriptable()):
|
||||||
|
if name in _ACT_LAYER_ME:
|
||||||
|
return _ACT_LAYER_ME[name]
|
||||||
|
if not is_no_jit():
|
||||||
|
if name in _ACT_LAYER_JIT:
|
||||||
|
return _ACT_LAYER_JIT[name]
|
||||||
|
return _ACT_LAYER_DEFAULT[name]
|
||||||
|
|
||||||
|
|
||||||
|
def create_act_layer(name, inplace=False, **kwargs):
|
||||||
|
act_layer = get_act_layer(name)
|
||||||
|
if act_layer is not None:
|
||||||
|
return act_layer(inplace=inplace, **kwargs)
|
||||||
|
else:
|
||||||
|
return None
|
@ -1,37 +1,64 @@
|
|||||||
|
import types
|
||||||
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||||
from .norm_act import BatchNormAct2d
|
from .norm_act import BatchNormAct2d, GroupNormAct
|
||||||
try:
|
from .inplace_abn import InplaceAbn
|
||||||
from inplace_abn import InPlaceABN
|
|
||||||
has_iabn = True
|
|
||||||
except ImportError:
|
|
||||||
has_iabn = False
|
|
||||||
|
|
||||||
|
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
|
||||||
|
|
||||||
def create_norm_act(layer_type, num_features, jit=False, **kwargs):
|
|
||||||
layer_parts = layer_type.split('_')
|
def get_norm_act_layer(layer_class):
|
||||||
assert len(layer_parts) in (1, 2)
|
layer_class = layer_class.replace('_', '').lower()
|
||||||
layer_class = layer_parts[0].lower()
|
if layer_class.startswith("batchnorm"):
|
||||||
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection
|
layer = BatchNormAct2d
|
||||||
|
elif layer_class.startswith("groupnorm"):
|
||||||
if layer_class == "batchnormact":
|
layer = GroupNormAct
|
||||||
layer = BatchNormAct2d(num_features, **kwargs) # defaults to RELU of no kwargs override
|
|
||||||
elif layer_class == "batchnormrelu":
|
|
||||||
assert 'act_layer' not in kwargs
|
|
||||||
layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
|
|
||||||
elif layer_class == "evonormbatch":
|
elif layer_class == "evonormbatch":
|
||||||
layer = EvoNormBatch2d(num_features, **kwargs)
|
layer = EvoNormBatch2d
|
||||||
elif layer_class == "evonormsample":
|
elif layer_class == "evonormsample":
|
||||||
layer = EvoNormSample2d(num_features, **kwargs)
|
layer = EvoNormSample2d
|
||||||
elif layer_class == "iabn" or layer_class == "inplaceabn":
|
elif layer_class == "iabn" or layer_class == "inplaceabn":
|
||||||
if not has_iabn:
|
layer = InplaceAbn
|
||||||
raise ImportError(
|
|
||||||
"Pplease install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
|
|
||||||
layer = InPlaceABN(num_features, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
assert False, "Invalid norm_act layer (%s)" % layer_class
|
assert False, "Invalid norm_act layer (%s)" % layer_class
|
||||||
if jit:
|
|
||||||
layer = torch.jit.script(layer)
|
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
|
||||||
|
layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu
|
||||||
|
assert len(layer_parts) in (1, 2)
|
||||||
|
layer = get_norm_act_layer(layer_parts[0])
|
||||||
|
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
|
||||||
|
layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
|
||||||
|
if jit:
|
||||||
|
layer_instance = torch.jit.script(layer_instance)
|
||||||
|
return layer_instance
|
||||||
|
|
||||||
|
|
||||||
|
def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
|
||||||
|
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
|
||||||
|
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
|
||||||
|
norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
|
||||||
|
if isinstance(norm_layer, str):
|
||||||
|
norm_act_layer = get_norm_act_layer(norm_layer)
|
||||||
|
elif norm_layer in _NORM_ACT_TYPES:
|
||||||
|
norm_act_layer = norm_layer
|
||||||
|
elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
|
||||||
|
# assuming this is a lambda/fn/bound partial that creates norm_act layer
|
||||||
|
norm_act_layer = norm_layer
|
||||||
|
else:
|
||||||
|
type_name = norm_layer.__name__.lower()
|
||||||
|
if type_name.startswith('batchnorm'):
|
||||||
|
norm_act_layer = BatchNormAct2d
|
||||||
|
elif type_name.startswith('groupnorm'):
|
||||||
|
norm_act_layer = GroupNormAct
|
||||||
|
else:
|
||||||
|
assert False, f"No equivalent norm_act layer for {type_name}"
|
||||||
|
# Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
|
||||||
|
# Newer models will use `apply_act` and likely have `act_layer` arg bound to relevant NormAct types.
|
||||||
|
norm_act_args.update(dict(act_layer=act_layer))
|
||||||
|
return norm_act_layer, norm_act_args
|
||||||
|
@ -0,0 +1,85 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
try:
|
||||||
|
from inplace_abn.functions import inplace_abn, inplace_abn_sync
|
||||||
|
has_iabn = True
|
||||||
|
except ImportError:
|
||||||
|
has_iabn = False
|
||||||
|
|
||||||
|
def inplace_abn(x, weight, bias, running_mean, running_var,
|
||||||
|
training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
|
||||||
|
raise ImportError(
|
||||||
|
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
|
||||||
|
|
||||||
|
def inplace_abn_sync(**kwargs):
|
||||||
|
inplace_abn(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class InplaceAbn(nn.Module):
|
||||||
|
"""Activated Batch Normalization
|
||||||
|
|
||||||
|
This gathers a BatchNorm and an activation function in a single module
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
num_features : int
|
||||||
|
Number of feature channels in the input and output.
|
||||||
|
eps : float
|
||||||
|
Small constant to prevent numerical issues.
|
||||||
|
momentum : float
|
||||||
|
Momentum factor applied to compute running statistics.
|
||||||
|
affine : bool
|
||||||
|
If `True` apply learned scale and shift transformation after normalization.
|
||||||
|
act_layer : str or nn.Module type
|
||||||
|
Name or type of the activation functions, one of: `leaky_relu`, `elu`
|
||||||
|
act_param : float
|
||||||
|
Negative slope for the `leaky_relu` activation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
|
||||||
|
act_layer="leaky_relu", act_param=0.01, drop_block=None,):
|
||||||
|
super(InplaceAbn, self).__init__()
|
||||||
|
self.num_features = num_features
|
||||||
|
self.affine = affine
|
||||||
|
self.eps = eps
|
||||||
|
self.momentum = momentum
|
||||||
|
if apply_act:
|
||||||
|
if isinstance(act_layer, str):
|
||||||
|
assert act_layer in ('leaky_relu', 'elu', 'identity')
|
||||||
|
self.act_name = act_layer
|
||||||
|
else:
|
||||||
|
# convert act layer passed as type to string
|
||||||
|
if isinstance(act_layer, nn.ELU):
|
||||||
|
self.act_name = 'elu'
|
||||||
|
elif isinstance(act_layer, nn.LeakyReLU):
|
||||||
|
self.act_name = 'leaky_relu'
|
||||||
|
else:
|
||||||
|
assert False, f'Invalid act layer {act_layer.__name__} for IABN'
|
||||||
|
else:
|
||||||
|
self.act_name = 'identity'
|
||||||
|
self.act_param = act_param
|
||||||
|
if self.affine:
|
||||||
|
self.weight = nn.Parameter(torch.ones(num_features))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(num_features))
|
||||||
|
else:
|
||||||
|
self.register_parameter('weight', None)
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
self.register_buffer('running_mean', torch.zeros(num_features))
|
||||||
|
self.register_buffer('running_var', torch.ones(num_features))
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.constant_(self.running_mean, 0)
|
||||||
|
nn.init.constant_(self.running_var, 1)
|
||||||
|
if self.affine:
|
||||||
|
nn.init.constant_(self.weight, 1)
|
||||||
|
nn.init.constant_(self.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = inplace_abn(
|
||||||
|
x, self.weight, self.bias, self.running_mean, self.running_var,
|
||||||
|
self.training, self.momentum, self.eps, self.act_name, self.act_param)
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = output[0]
|
||||||
|
return output
|
@ -0,0 +1,51 @@
|
|||||||
|
from torch import nn as nn
|
||||||
|
|
||||||
|
from .create_conv2d import create_conv2d
|
||||||
|
from .create_norm_act import convert_norm_act_type
|
||||||
|
|
||||||
|
|
||||||
|
class SeparableConvBnAct(nn.Module):
|
||||||
|
""" Separable Conv w/ trailing Norm and Activation
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
|
||||||
|
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
|
||||||
|
act_layer=nn.ReLU, apply_act=True, drop_block=None):
|
||||||
|
super(SeparableConvBnAct, self).__init__()
|
||||||
|
norm_kwargs = norm_kwargs or {}
|
||||||
|
|
||||||
|
self.conv_dw = create_conv2d(
|
||||||
|
in_channels, int(in_channels * channel_multiplier), kernel_size,
|
||||||
|
stride=stride, dilation=dilation, padding=padding, depthwise=True)
|
||||||
|
|
||||||
|
self.conv_pw = create_conv2d(
|
||||||
|
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
|
||||||
|
|
||||||
|
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
|
||||||
|
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_dw(x)
|
||||||
|
x = self.conv_pw(x)
|
||||||
|
if self.bn is not None:
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SeparableConv2d(nn.Module):
|
||||||
|
""" Separable Conv
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
|
||||||
|
channel_multiplier=1.0, pw_kernel_size=1):
|
||||||
|
super(SeparableConv2d, self).__init__()
|
||||||
|
|
||||||
|
self.conv_dw = create_conv2d(
|
||||||
|
in_channels, int(in_channels * channel_multiplier), kernel_size,
|
||||||
|
stride=stride, dilation=dilation, padding=padding, depthwise=True)
|
||||||
|
|
||||||
|
self.conv_pw = create_conv2d(
|
||||||
|
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_dw(x)
|
||||||
|
x = self.conv_pw(x)
|
||||||
|
return x
|
@ -0,0 +1,408 @@
|
|||||||
|
""" VoVNet (V1 & V2)
|
||||||
|
|
||||||
|
Papers:
|
||||||
|
* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
|
||||||
|
* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
|
||||||
|
|
||||||
|
Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
|
||||||
|
https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
|
||||||
|
for some reference, rewrote most of the code.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from .registry import register_model
|
||||||
|
from .helpers import load_pretrained
|
||||||
|
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
|
||||||
|
create_attn, create_norm_act, get_norm_act_layer
|
||||||
|
|
||||||
|
|
||||||
|
# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
|
||||||
|
# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
|
||||||
|
model_cfgs = dict(
|
||||||
|
vovnet39a=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=5,
|
||||||
|
block_per_stage=[1, 1, 2, 2],
|
||||||
|
residual=False,
|
||||||
|
depthwise=False,
|
||||||
|
attn='',
|
||||||
|
),
|
||||||
|
vovnet57a=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=5,
|
||||||
|
block_per_stage=[1, 1, 4, 3],
|
||||||
|
residual=False,
|
||||||
|
depthwise=False,
|
||||||
|
attn='',
|
||||||
|
|
||||||
|
),
|
||||||
|
ese_vovnet19b_slim_dw=dict(
|
||||||
|
stem_ch=[64, 64, 64],
|
||||||
|
stage_conv_ch=[64, 80, 96, 112],
|
||||||
|
stage_out_ch=[112, 256, 384, 512],
|
||||||
|
layer_per_block=3,
|
||||||
|
block_per_stage=[1, 1, 1, 1],
|
||||||
|
residual=True,
|
||||||
|
depthwise=True,
|
||||||
|
attn='ese',
|
||||||
|
|
||||||
|
),
|
||||||
|
ese_vovnet19b_dw=dict(
|
||||||
|
stem_ch=[64, 64, 64],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=3,
|
||||||
|
block_per_stage=[1, 1, 1, 1],
|
||||||
|
residual=True,
|
||||||
|
depthwise=True,
|
||||||
|
attn='ese',
|
||||||
|
),
|
||||||
|
ese_vovnet19b_slim=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[64, 80, 96, 112],
|
||||||
|
stage_out_ch=[112, 256, 384, 512],
|
||||||
|
layer_per_block=3,
|
||||||
|
block_per_stage=[1, 1, 1, 1],
|
||||||
|
residual=True,
|
||||||
|
depthwise=False,
|
||||||
|
attn='ese',
|
||||||
|
),
|
||||||
|
ese_vovnet19b=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=3,
|
||||||
|
block_per_stage=[1, 1, 1, 1],
|
||||||
|
residual=True,
|
||||||
|
depthwise=False,
|
||||||
|
attn='ese',
|
||||||
|
|
||||||
|
),
|
||||||
|
ese_vovnet39b=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=5,
|
||||||
|
block_per_stage=[1, 1, 2, 2],
|
||||||
|
residual=True,
|
||||||
|
depthwise=False,
|
||||||
|
attn='ese',
|
||||||
|
),
|
||||||
|
ese_vovnet57b=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=5,
|
||||||
|
block_per_stage=[1, 1, 4, 3],
|
||||||
|
residual=True,
|
||||||
|
depthwise=False,
|
||||||
|
attn='ese',
|
||||||
|
|
||||||
|
),
|
||||||
|
ese_vovnet99b=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=5,
|
||||||
|
block_per_stage=[1, 3, 9, 3],
|
||||||
|
residual=True,
|
||||||
|
depthwise=False,
|
||||||
|
attn='ese',
|
||||||
|
),
|
||||||
|
eca_vovnet39b=dict(
|
||||||
|
stem_ch=[64, 64, 128],
|
||||||
|
stage_conv_ch=[128, 160, 192, 224],
|
||||||
|
stage_out_ch=[256, 512, 768, 1024],
|
||||||
|
layer_per_block=5,
|
||||||
|
block_per_stage=[1, 1, 2, 2],
|
||||||
|
residual=True,
|
||||||
|
depthwise=False,
|
||||||
|
attn='eca',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url=''):
|
||||||
|
return {
|
||||||
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
|
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = dict(
|
||||||
|
vovnet39a=_cfg(url=''),
|
||||||
|
vovnet57a=_cfg(url=''),
|
||||||
|
ese_vovnet19b_slim_dw=_cfg(url=''),
|
||||||
|
ese_vovnet19b_dw=_cfg(url=''),
|
||||||
|
ese_vovnet19b_slim=_cfg(url=''),
|
||||||
|
ese_vovnet39b=_cfg(url=''),
|
||||||
|
ese_vovnet57b=_cfg(url=''),
|
||||||
|
ese_vovnet99b=_cfg(url=''),
|
||||||
|
eca_vovnet39b=_cfg(url=''),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialAppendList(nn.Sequential):
|
||||||
|
def __init__(self, *args):
|
||||||
|
super(SequentialAppendList, self).__init__(*args)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
|
||||||
|
for i, module in enumerate(self):
|
||||||
|
if i == 0:
|
||||||
|
concat_list.append(module(x))
|
||||||
|
else:
|
||||||
|
concat_list.append(module(concat_list[-1]))
|
||||||
|
x = torch.cat(concat_list, dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class OsaBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
|
||||||
|
depthwise=False, attn='', norm_layer=BatchNormAct2d):
|
||||||
|
super(OsaBlock, self).__init__()
|
||||||
|
|
||||||
|
self.residual = residual
|
||||||
|
self.depthwise = depthwise
|
||||||
|
|
||||||
|
next_in_chs = in_chs
|
||||||
|
if self.depthwise and next_in_chs != mid_chs:
|
||||||
|
assert not residual
|
||||||
|
self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, norm_layer=norm_layer)
|
||||||
|
else:
|
||||||
|
self.conv_reduction = None
|
||||||
|
|
||||||
|
mid_convs = []
|
||||||
|
for i in range(layer_per_block):
|
||||||
|
if self.depthwise:
|
||||||
|
conv = SeparableConvBnAct(mid_chs, mid_chs, norm_layer=norm_layer)
|
||||||
|
else:
|
||||||
|
conv = ConvBnAct(next_in_chs, mid_chs, 3, norm_layer=norm_layer)
|
||||||
|
next_in_chs = mid_chs
|
||||||
|
mid_convs.append(conv)
|
||||||
|
self.conv_mid = SequentialAppendList(*mid_convs)
|
||||||
|
|
||||||
|
# feature aggregation
|
||||||
|
next_in_chs = in_chs + layer_per_block * mid_chs
|
||||||
|
self.conv_concat = ConvBnAct(next_in_chs, out_chs, norm_layer=norm_layer)
|
||||||
|
|
||||||
|
if attn:
|
||||||
|
self.attn = create_attn(attn, out_chs)
|
||||||
|
else:
|
||||||
|
self.attn = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = [x]
|
||||||
|
if self.conv_reduction is not None:
|
||||||
|
x = self.conv_reduction(x)
|
||||||
|
x = self.conv_mid(x, output)
|
||||||
|
x = self.conv_concat(x)
|
||||||
|
if self.attn is not None:
|
||||||
|
x = self.attn(x)
|
||||||
|
if self.residual:
|
||||||
|
x = x + output[0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class OsaStage(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block,
|
||||||
|
downsample=True, residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d):
|
||||||
|
super(OsaStage, self).__init__()
|
||||||
|
|
||||||
|
if downsample:
|
||||||
|
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
|
||||||
|
else:
|
||||||
|
self.pool = None
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
for i in range(block_per_stage):
|
||||||
|
last_block = i == block_per_stage - 1
|
||||||
|
blocks += [OsaBlock(
|
||||||
|
in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0,
|
||||||
|
depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer)
|
||||||
|
]
|
||||||
|
self.blocks = nn.Sequential(*blocks)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.pool is not None:
|
||||||
|
x = self.pool(x)
|
||||||
|
x = self.blocks(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierHead(nn.Module):
|
||||||
|
"""Head."""
|
||||||
|
|
||||||
|
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
|
||||||
|
super(ClassifierHead, self).__init__()
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
|
||||||
|
if num_classes > 0:
|
||||||
|
self.fc = nn.Linear(in_chs, num_classes, bias=True)
|
||||||
|
else:
|
||||||
|
self.fc = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.global_pool(x).flatten(1)
|
||||||
|
if self.drop_rate:
|
||||||
|
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VovNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
|
||||||
|
norm_layer=BatchNormAct2d):
|
||||||
|
""" VovNet (v2)
|
||||||
|
"""
|
||||||
|
super(VovNet, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
assert stem_stride in (4, 2)
|
||||||
|
|
||||||
|
stem_ch = cfg["stem_ch"]
|
||||||
|
stage_conv_ch = cfg["stage_conv_ch"]
|
||||||
|
stage_out_ch = cfg["stage_out_ch"]
|
||||||
|
block_per_stage = cfg["block_per_stage"]
|
||||||
|
layer_per_block = cfg["layer_per_block"]
|
||||||
|
|
||||||
|
# Stem module
|
||||||
|
last_stem_stride = stem_stride // 2
|
||||||
|
conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct
|
||||||
|
self.stem = nn.Sequential(*[
|
||||||
|
ConvBnAct(in_chans, stem_ch[0], 3, stride=2, norm_layer=norm_layer),
|
||||||
|
conv_type(stem_ch[0], stem_ch[1], 3, stride=1, norm_layer=norm_layer),
|
||||||
|
conv_type(stem_ch[1], stem_ch[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
|
||||||
|
])
|
||||||
|
|
||||||
|
# OSA stages
|
||||||
|
in_ch_list = stem_ch[-1:] + stage_out_ch[:-1]
|
||||||
|
stage_args = dict(
|
||||||
|
residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], norm_layer=norm_layer)
|
||||||
|
stages = []
|
||||||
|
for i in range(4): # num_stages
|
||||||
|
downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
|
||||||
|
stages += [OsaStage(
|
||||||
|
in_ch_list[i], stage_conv_ch[i], stage_out_ch[i], block_per_stage[i], layer_per_block,
|
||||||
|
downsample=downsample, **stage_args)
|
||||||
|
]
|
||||||
|
self.num_features = stage_out_ch[i]
|
||||||
|
self.stages = nn.Sequential(*stages)
|
||||||
|
|
||||||
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
||||||
|
|
||||||
|
for n, m in self.named_modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1.)
|
||||||
|
nn.init.constant_(m.bias, 0.)
|
||||||
|
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.head.fc
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
x = self.stem(x)
|
||||||
|
return self.stages(x)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
return self.head(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _vovnet(variant, pretrained=False, **kwargs):
|
||||||
|
load_strict = True
|
||||||
|
model_class = VovNet
|
||||||
|
if kwargs.pop('features_only', False):
|
||||||
|
assert False, 'Not Implemented' # TODO
|
||||||
|
load_strict = False
|
||||||
|
kwargs.pop('num_classes', 0)
|
||||||
|
model_cfg = model_cfgs[variant]
|
||||||
|
default_cfg = default_cfgs[variant]
|
||||||
|
model = model_class(model_cfg, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(
|
||||||
|
model, default_cfg,
|
||||||
|
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vovnet39a(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('vovnet39a', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vovnet57a(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('vovnet57a', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet19b_slim_dw(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet19b_dw(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet19b_slim(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet39b(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet57b(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet99b(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def eca_vovnet39b(pretrained=False, **kwargs):
|
||||||
|
return _vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Experimental Models
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet39b_iabn(pretrained=False, **kwargs):
|
||||||
|
norm_layer = get_norm_act_layer('iabn')
|
||||||
|
return _vovnet('ese_vovnet39b', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ese_vovnet39b_evos(pretrained=False, **kwargs):
|
||||||
|
def norm_act_fn(num_features, **kwargs):
|
||||||
|
return create_norm_act('EvoNormSample', num_features, jit=False, **kwargs)
|
||||||
|
return _vovnet('ese_vovnet39b', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
|
Loading…
Reference in new issue